From ceecf1cea92dc2d8c29b1364237ac9467abf2f9b Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 29 May 2020 16:20:04 +0200 Subject: [PATCH] Graceful shutdown on python interpreter exit (#1631) * Graceful shutdown on python interpreter exit * Update CHANGELOG.md * Update training_loop.py * Update training_loop.py * Update CHANGELOG.md Co-Authored-By: Jirka Borovec * pep8, move to constant * Update training_loop.py * Update training_loop.py * Update training_loop.py * pep8, move to constant * pep8 * timeout Co-authored-by: Jirka Borovec Co-authored-by: Jirka --- .circleci/config.yml | 2 +- CHANGELOG.md | 2 ++ pytorch_lightning/trainer/training_loop.py | 30 +++++++++++++++++++--- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 2237e39423..1cd6ac7a4d 100755 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -22,7 +22,7 @@ references: command: | python --version ; pip --version ; pip list py.test pytorch_lightning tests -v --doctest-modules --junitxml=test-reports/pytest_junit.xml - no_output_timeout: 30m + no_output_timeout: 15m examples: &examples run: diff --git a/CHANGELOG.md b/CHANGELOG.md index 382d4f8ff7..e10d13976f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed +- Run graceful training teardown on interpreter exit ([#1631](https://github.com/PyTorchLightning/pytorch-lightning/pull/1631)) + - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873)) - Fixed multiple calls of `EarlyStopping` callback ([#1751](https://github.com/PyTorchLightning/pytorch-lightning/issues/1751)) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 72fb77065f..ff3ed0e4fe 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -141,21 +141,23 @@ in your model. """ +import atexit +import signal from abc import ABC, abstractmethod from typing import Callable from typing import Union, List import numpy as np -from torch.utils.data import DataLoader import torch +from torch.utils.data import DataLoader from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: from apex import amp @@ -179,9 +181,11 @@ except ImportError: else: HOROVOD_AVAILABLE = True +# constant listing which signals should be caught for graceful trainer shutdown +SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT') + class TrainerTrainLoopMixin(ABC): - # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class max_epochs: int @@ -300,6 +304,15 @@ class TrainerTrainLoopMixin(ABC): """Warning: this is just empty shell for code implemented in other class.""" def train(self): + # 
add signal handlers for process kills + def _signal_kill_handler(*args): + return TrainerTrainLoopMixin.run_training_teardown(self) + + orig_signal_handlers = {} + for sig_name in SIGNAL_TERMINATE: + orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name), + _signal_kill_handler) + # get model model = self.get_model() @@ -371,6 +384,10 @@ class TrainerTrainLoopMixin(ABC): self.run_training_teardown() + # reset signal handlers + for sig_name in SIGNAL_TERMINATE: + signal.signal(getattr(signal, sig_name), orig_signal_handlers[sig_name]) + except KeyboardInterrupt: if self.proc_rank == 0: log.info('Detected KeyboardInterrupt, attempting graceful shutdown...') @@ -405,7 +422,7 @@ class TrainerTrainLoopMixin(ABC): # run epoch for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( - enumerate(_with_is_last(train_dataloader)), "get_train_batch" + enumerate(_with_is_last(train_dataloader)), "get_train_batch" ): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: @@ -661,7 +678,10 @@ class TrainerTrainLoopMixin(ABC): opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) return [(opt_idx, self.optimizers[opt_idx])] + @atexit.register def run_training_teardown(self): + if hasattr(self, '_teardown_already_run') and self._teardown_already_run: + return # Train end events with self.profiler.profile('on_train_end'): # callbacks @@ -676,6 +696,8 @@ class TrainerTrainLoopMixin(ABC): # summarize profile results self.profiler.describe() + self._teardown_already_run = True + def training_forward(self, batch, batch_idx, opt_idx, hiddens): """ Handle forward for each training case (distributed, single gpu, etc...)