Handle KeyboardInterrupt during training (#2134)
* Handle KeyboardInterrupt during training Fixes #2079. * chlog * Fix whitespace * Update callback_hook.py * Update base.py * Update training_loop.py * Update test_trainer.py * Update CHANGELOG.md Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * Update CHANGELOG.md * on_keyboard_interrupt Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
bd3a1f7dd4
commit
fd1693e289
|
@ -21,7 +21,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
|
||||
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
|
||||
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
|
||||
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
|
||||
|
@ -35,6 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
|
||||
- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
|
||||
- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))
|
||||
- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))
|
||||
|
||||
### Changed
|
||||
|
||||
|
@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Changed the default value of the Trainer argument `weights_summary` from `full` to `top` ([#2029](https://github.com/PyTorchLightning/pytorch-lightning/pull/2029))
|
||||
- Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020))
|
||||
- Enabled prepare_data from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166))
|
||||
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
|
|
@ -85,3 +85,6 @@ class Callback(abc.ABC):
|
|||
def on_test_end(self, trainer, pl_module):
|
||||
"""Called when the test ends."""
|
||||
pass
|
||||
|
||||
def on_keyboard_interrupt(self, trainer, pl_module):
|
||||
"""Called when the training is interrupted by KeyboardInterrupt."""
|
||||
|
|
|
@ -100,3 +100,8 @@ class TrainerCallbackHookMixin(ABC):
|
|||
"""Called when the test ends."""
|
||||
for callback in self.callbacks:
|
||||
callback.on_test_end(self, self.get_model())
|
||||
|
||||
def on_keyboard_interrupt(self):
|
||||
"""Called when the training is interrupted by KeyboardInterrupt."""
|
||||
for callback in self.callbacks:
|
||||
callback.on_keyboard_interrupt(self, self.get_model())
|
||||
|
|
|
@ -237,6 +237,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
checkpoint_callback: ...
|
||||
terminate_on_nan: bool
|
||||
tpu_id: int
|
||||
interactive_ddp_procs: ...
|
||||
|
||||
# Callback system
|
||||
callbacks: List[Callback]
|
||||
|
@ -247,6 +248,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
on_epoch_start: Callable
|
||||
on_epoch_end: Callable
|
||||
on_validation_end: Callable
|
||||
on_keyboard_interrupt: Callable
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self) -> LightningModule:
|
||||
|
@ -395,6 +397,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# user could press ctrl+c many times... only shutdown once
|
||||
if not self.interrupted:
|
||||
self.interrupted = True
|
||||
self.on_keyboard_interrupt()
|
||||
|
||||
for proc in self.interactive_ddp_procs:
|
||||
subprocess.Popen.kill(proc)
|
||||
|
|
|
@ -3,6 +3,7 @@ import math
|
|||
import os
|
||||
import pickle
|
||||
import types
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
|
||||
import cloudpickle
|
||||
|
@ -10,10 +11,10 @@ import pytest
|
|||
import torch
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Callback, LightningModule
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning import Callback, LightningModule, Trainer
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
|
||||
from pytorch_lightning.core.saving import (
|
||||
load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv)
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
|
||||
from pytorch_lightning.utilities.io import load as pl_load
|
||||
|
@ -660,10 +661,19 @@ def test_trainer_interrupted_flag(tmpdir):
|
|||
def on_batch_start(self, trainer, pl_module):
|
||||
raise KeyboardInterrupt
|
||||
|
||||
class HandleInterruptCallback(Callback):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.exc_info = None
|
||||
|
||||
def on_keyboard_interrupt(self, trainer, pl_module):
|
||||
self.exc_info = sys.exc_info()
|
||||
|
||||
interrupt_callback = InterruptCallback()
|
||||
handle_interrupt_callback = HandleInterruptCallback()
|
||||
|
||||
trainer = Trainer(
|
||||
callbacks=[interrupt_callback],
|
||||
callbacks=[interrupt_callback, handle_interrupt_callback],
|
||||
max_epochs=1,
|
||||
val_percent_check=0.1,
|
||||
train_percent_check=0.2,
|
||||
|
@ -672,8 +682,10 @@ def test_trainer_interrupted_flag(tmpdir):
|
|||
default_root_dir=tmpdir,
|
||||
)
|
||||
assert not trainer.interrupted
|
||||
assert handle_interrupt_callback.exc_info is None
|
||||
trainer.fit(model)
|
||||
assert trainer.interrupted
|
||||
assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt)
|
||||
|
||||
|
||||
def test_gradient_clipping(tmpdir):
|
||||
|
|
Loading…
Reference in New Issue