diff --git a/CHANGELOG.md b/CHANGELOG.md index df6ed5b2d6..efae799ae4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126)) - Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877)) - Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327)) - Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488)) @@ -35,6 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610)) - Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115)) - Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667)) +- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134)) ### Changed @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the default value of the Trainer argument `weights_summary` from `full` to `top` ([#2029](https://github.com/PyTorchLightning/pytorch-lightning/pull/2029)) - Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020)) - Enabled prepare_data from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166)) +- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126)) ### Deprecated diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 50ea061df6..1f21435e63 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -85,3 +85,6 @@ class Callback(abc.ABC): def on_test_end(self, trainer, pl_module): """Called when the test ends.""" pass + + def on_keyboard_interrupt(self, trainer, pl_module): + """Called when the training is interrupted by KeyboardInterrupt.""" diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 0ba6a54ddd..6d89ea3abb 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -100,3 +100,8 @@ class TrainerCallbackHookMixin(ABC): """Called when the test ends.""" for callback in self.callbacks: callback.on_test_end(self, self.get_model()) + + def on_keyboard_interrupt(self): + """Called when the training is interrupted by KeyboardInterrupt.""" + for callback in self.callbacks: + callback.on_keyboard_interrupt(self, self.get_model()) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index dac73c6747..a80f3b6283 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -237,6 +237,7 @@ class TrainerTrainLoopMixin(ABC): checkpoint_callback: ... terminate_on_nan: bool tpu_id: int + interactive_ddp_procs: ... # Callback system callbacks: List[Callback] @@ -247,6 +248,7 @@ class TrainerTrainLoopMixin(ABC): on_epoch_start: Callable on_epoch_end: Callable on_validation_end: Callable + on_keyboard_interrupt: Callable @abstractmethod def get_model(self) -> LightningModule: @@ -395,6 +397,7 @@ class TrainerTrainLoopMixin(ABC): # user could press ctrl+c many times... only shutdown once if not self.interrupted: self.interrupted = True + self.on_keyboard_interrupt() for proc in self.interactive_ddp_procs: subprocess.Popen.kill(proc) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1e99236e7b..c5965da3c0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -3,6 +3,7 @@ import math import os import pickle import types +import sys from argparse import Namespace import cloudpickle @@ -10,10 +11,10 @@ import pytest import torch import tests.base.utils as tutils -from pytorch_lightning import Callback, LightningModule -from pytorch_lightning import Trainer +from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv +from pytorch_lightning.core.saving import ( + load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv) from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.utilities.io import load as pl_load @@ -660,10 +661,19 @@ def test_trainer_interrupted_flag(tmpdir): def on_batch_start(self, trainer, pl_module): raise KeyboardInterrupt + class HandleInterruptCallback(Callback): + def __init__(self): + super().__init__() + self.exc_info = None + + def on_keyboard_interrupt(self, trainer, pl_module): + self.exc_info = sys.exc_info() + interrupt_callback = InterruptCallback() + handle_interrupt_callback = HandleInterruptCallback() trainer = Trainer( - callbacks=[interrupt_callback], + callbacks=[interrupt_callback, handle_interrupt_callback], max_epochs=1, val_percent_check=0.1, train_percent_check=0.2, @@ -672,8 +682,10 @@ def test_trainer_interrupted_flag(tmpdir): default_root_dir=tmpdir, ) assert not trainer.interrupted + assert handle_interrupt_callback.exc_info is None trainer.fit(model) assert trainer.interrupted + assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt) def test_gradient_clipping(tmpdir):