Handle KeyboardInterrupt during training (#2134)

* Handle KeyboardInterrupt during training

Fixes #2079.

* chlog

* Fix whitespace

* Update callback_hook.py

* Update base.py

* Update training_loop.py

* Update test_trainer.py

* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update CHANGELOG.md

* on_keyboard_interrupt

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Simon-Martin Schröder 2020-06-15 12:35:26 +02:00 committed by GitHub
parent bd3a1f7dd4
commit fd1693e289
5 changed files with 29 additions and 5 deletions

CHANGELOG.md

@@ -21,7 +21,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
@@ -35,6 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))
- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))
### Changed
@@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default value of the Trainer argument `weights_summary` from `full` to `top` ([#2029](https://github.com/PyTorchLightning/pytorch-lightning/pull/2029))
- Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020))
- Enabled prepare_data from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166))
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
### Deprecated
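
For context on the entry above: the new `on_keyboard_interrupt` hook lets a user callback react before the trainer shuts down. A minimal usage sketch (the `GracefulInterrupt` class name and the checkpoint filename are illustrative, not part of this PR):

```python
from pytorch_lightning import Callback, Trainer


class GracefulInterrupt(Callback):
    """Illustrative callback: persist a checkpoint when training is interrupted."""

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Runs inside the trainer's KeyboardInterrupt handler, so keep the
        # work minimal -- the process is about to shut down.
        trainer.save_checkpoint("interrupted.ckpt")


# Registered like any other callback:
# trainer = Trainer(callbacks=[GracefulInterrupt()], max_epochs=10)
# trainer.fit(model)
```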

base.py

@@ -85,3 +85,6 @@ class Callback(abc.ABC):
    def on_test_end(self, trainer, pl_module):
        """Called when the test ends."""
        pass

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Called when the training is interrupted by KeyboardInterrupt."""

callback_hook.py

@@ -100,3 +100,8 @@ class TrainerCallbackHookMixin(ABC):
        """Called when the test ends."""
        for callback in self.callbacks:
            callback.on_test_end(self, self.get_model())

    def on_keyboard_interrupt(self):
        """Called when the training is interrupted by KeyboardInterrupt."""
        for callback in self.callbacks:
            callback.on_keyboard_interrupt(self, self.get_model())
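
The dispatch above is a plain fan-out: each registered callback is notified in registration order. A self-contained sketch of that pattern outside Lightning (the class names are illustrative):

```python
events = []


class First:
    def on_keyboard_interrupt(self, trainer, pl_module):
        events.append("first")


class Second:
    def on_keyboard_interrupt(self, trainer, pl_module):
        events.append("second")


# Same loop shape as TrainerCallbackHookMixin.on_keyboard_interrupt above:
for callback in [First(), Second()]:
    callback.on_keyboard_interrupt(None, None)

assert events == ["first", "second"]
```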

training_loop.py

@@ -237,6 +237,7 @@ class TrainerTrainLoopMixin(ABC):
    checkpoint_callback: ...
    terminate_on_nan: bool
    tpu_id: int
    interactive_ddp_procs: ...

    # Callback system
    callbacks: List[Callback]
@@ -247,6 +248,7 @@ class TrainerTrainLoopMixin(ABC):
    on_epoch_start: Callable
    on_epoch_end: Callable
    on_validation_end: Callable
    on_keyboard_interrupt: Callable

    @abstractmethod
    def get_model(self) -> LightningModule:
@@ -395,6 +397,7 @@ class TrainerTrainLoopMixin(ABC):
            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self.on_keyboard_interrupt()
                for proc in self.interactive_ddp_procs:
                    subprocess.Popen.kill(proc)
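
For readers without the surrounding code, a simplified, self-contained sketch of the control flow this hunk plugs into: the training loop catches KeyboardInterrupt, flips `interrupted` exactly once, and notifies callbacks before shutting down. `ToyTrainer` is illustrative, not the real `TrainerTrainLoopMixin`.

```python
import time


class ToyTrainer:
    """Toy stand-in for the trainer, only to illustrate the interrupt handling."""

    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.interrupted = False

    def on_keyboard_interrupt(self):
        # Fan the event out to every registered callback.
        for callback in self.callbacks:
            callback.on_keyboard_interrupt(self, None)

    def fit(self, num_epochs=3):
        try:
            for _ in range(num_epochs):
                time.sleep(0.1)  # pretend to train; Ctrl-C here raises KeyboardInterrupt
        except KeyboardInterrupt:
            # the user could press ctrl+c many times... only shut down once
            if not self.interrupted:
                self.interrupted = True
                self.on_keyboard_interrupt()  # the hook added by this PR
            # the real trainer additionally kills interactive DDP child processes here
```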

test_trainer.py

@@ -3,6 +3,7 @@ import math
import os
import pickle
import types
import sys
from argparse import Namespace
import cloudpickle
@@ -10,10 +11,10 @@ import pytest
import torch
import tests.base.utils as tutils
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.core.saving import (
    load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.io import load as pl_load
@@ -660,10 +661,19 @@ def test_trainer_interrupted_flag(tmpdir):
        def on_batch_start(self, trainer, pl_module):
            raise KeyboardInterrupt

    class HandleInterruptCallback(Callback):
        def __init__(self):
            super().__init__()
            self.exc_info = None

        def on_keyboard_interrupt(self, trainer, pl_module):
            self.exc_info = sys.exc_info()

    interrupt_callback = InterruptCallback()
    handle_interrupt_callback = HandleInterruptCallback()

    trainer = Trainer(
        callbacks=[interrupt_callback],
        callbacks=[interrupt_callback, handle_interrupt_callback],
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
@@ -672,8 +682,10 @@ def test_trainer_interrupted_flag(tmpdir):
        default_root_dir=tmpdir,
    )
    assert not trainer.interrupted
    assert handle_interrupt_callback.exc_info is None
    trainer.fit(model)
    assert trainer.interrupted
    assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt)
def test_gradient_clipping(tmpdir):
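
The assertions above work because the hook fires inside the trainer's `except KeyboardInterrupt:` block, so `sys.exc_info()` still reports the active exception. A minimal standalone sketch of that behaviour:

```python
import sys


def capture():
    # Returns (type, value, traceback) for the exception currently being
    # handled anywhere up the call stack.
    return sys.exc_info()


try:
    raise KeyboardInterrupt
except KeyboardInterrupt:
    exc_type, exc_value, exc_tb = capture()
    assert exc_type is KeyboardInterrupt
    assert isinstance(exc_value, KeyboardInterrupt)
```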