Graceful shutdown on python interpreter exit (#1631)
* Graceful shutdown on python interpreter exit
* Update CHANGELOG.md
* Update training_loop.py
* Update training_loop.py
* Update CHANGELOG.md

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* pep8, move to constant
* Update training_loop.py
* Update training_loop.py
* Update training_loop.py
* pep8, move to constant
* pep8
* timeout

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in: parent 3af3f37d43 · commit ceecf1cea9
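For orientation before the diff: the change leans on two standard-library hooks, atexit.register (queue a callable for normal interpreter shutdown) and signal.signal (intercept terminating signals so teardown still runs when the process is killed). Below is a minimal, self-contained sketch of that combination; graceful_teardown and the signal tuple are illustrative names, not the Trainer API added in this commit.

import atexit
import signal

_done = False


def graceful_teardown(*_):
    """Idempotent teardown reachable from atexit, a signal handler, or the normal path."""
    global _done
    if _done:
        return
    _done = True
    print('flush loggers / close writers / summarize profilers here')


# run on normal interpreter exit
atexit.register(graceful_teardown)

# also run when the process is asked to terminate; keep the previous handlers
# so they can be restored once the work is done
previous = {
    name: signal.signal(getattr(signal, name), graceful_teardown)
    for name in ('SIGTERM', 'SIGINT')
}

try:
    graceful_teardown()  # the normal end-of-training path
finally:
    for name, handler in previous.items():
        signal.signal(getattr(signal, name), handler)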
@@ -22,7 +22,7 @@ references:
     command: |
       python --version ; pip --version ; pip list
       py.test pytorch_lightning tests -v --doctest-modules --junitxml=test-reports/pytest_junit.xml
-    no_output_timeout: 30m
+    no_output_timeout: 15m

 examples: &examples
   run:
@@ -36,6 +36,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Run graceful training teardown on interpreter exit ([#1631](https://github.com/PyTorchLightning/pytorch-lightning/pull/1631))
+
 - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))

 - Fixed multiple calls of `EarlyStopping` callback ([#1751](https://github.com/PyTorchLightning/pytorch-lightning/issues/1751))
@@ -141,21 +141,23 @@ in your model.

 """

+import atexit
+import signal
 from abc import ABC, abstractmethod
 from typing import Callable
 from typing import Union, List

 import numpy as np
-from torch.utils.data import DataLoader
 import torch
+from torch.utils.data import DataLoader

 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     from apex import amp
@@ -179,9 +181,11 @@ except ImportError:
 else:
     HOROVOD_AVAILABLE = True

+# constant which signals should be caught for graceful trainer shutdown
+SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')


 class TrainerTrainLoopMixin(ABC):

     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     max_epochs: int
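The entries in SIGNAL_TERMINATE are plain names that get resolved to signal objects at registration time via getattr(signal, sig_name). One caveat worth flagging (my reading, not something the diff states): handling SIGSEGV from Python is best-effort, since the interpreter may already be in a corrupted state when that signal arrives. A tiny illustration of the name-to-signal lookup the registration loop relies on:

import signal

SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')

for sig_name in SIGNAL_TERMINATE:
    sig = getattr(signal, sig_name)   # e.g. signal.SIGTERM
    print(sig_name, int(sig))         # 15, 11 and 2 on most POSIX platforms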
@@ -300,6 +304,15 @@ class TrainerTrainLoopMixin(ABC):
         """Warning: this is just empty shell for code implemented in other class."""

     def train(self):
+        # add signal handlers for process kills
+        def _signal_kill_handler(*args):
+            return TrainerTrainLoopMixin.run_training_teardown(self)
+
+        orig_signal_handlers = {}
+        for sig_name in SIGNAL_TERMINATE:
+            orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name),
+                                                           _signal_kill_handler)
+
         # get model
         model = self.get_model()

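Two details of the registration above are easy to miss: signal.signal returns whatever handler was previously installed (which is what orig_signal_handlers captures so the trainer can restore it later), and handlers can only be set from the main thread of the main interpreter, otherwise a ValueError is raised. A small, hedged sketch of that save/replace/restore cycle in isolation, with illustrative names:

import signal


def _replacement_handler(signum, frame):
    # a signal handler receives the signal number and the current stack frame
    print(f'caught signal {signum}, running graceful teardown instead of dying')


# install the new handler and keep whatever was there before
previous = signal.signal(signal.SIGTERM, _replacement_handler)
try:
    pass  # work that should be torn down gracefully goes here
finally:
    # hand control back to the original handler
    signal.signal(signal.SIGTERM, previous)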
@@ -371,6 +384,10 @@ class TrainerTrainLoopMixin(ABC):

             self.run_training_teardown()

+            # reset signal handlers
+            for sig_name in SIGNAL_TERMINATE:
+                signal.signal(getattr(signal, sig_name), orig_signal_handlers[sig_name])
+
         except KeyboardInterrupt:
             if self.proc_rank == 0:
                 log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
@@ -405,7 +422,7 @@ class TrainerTrainLoopMixin(ABC):

         # run epoch
         for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
-            enumerate(_with_is_last(train_dataloader)), "get_train_batch"
+                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
         ):
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:
@@ -661,7 +678,10 @@ class TrainerTrainLoopMixin(ABC):
         opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
         return [(opt_idx, self.optimizers[opt_idx])]

+    @atexit.register
     def run_training_teardown(self):
+        if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
+            return
         # Train end events
         with self.profiler.profile('on_train_end'):
             # callbacks
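atexit.register(func) schedules func to be called with no arguments at normal interpreter shutdown and returns func unchanged, which is why it can be stacked as a decorator: the decorated callable still works when invoked explicitly, and it is additionally queued for exit. Since the same teardown can then be reached from the end of train(), from a signal handler, and from the atexit hook, the _teardown_already_run check keeps the body from executing twice. A standalone sketch of that idempotence pattern on a module-level function (not the Trainer method itself):

import atexit

_teardown_already_run = False


@atexit.register
def run_teardown():
    # guard: the explicit call below and the interpreter-exit call must not both do work
    global _teardown_already_run
    if _teardown_already_run:
        return
    _teardown_already_run = True
    print('teardown ran exactly once')


run_teardown()  # explicit call during the normal flow; the later atexit invocation becomes a no-op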
@@ -676,6 +696,8 @@ class TrainerTrainLoopMixin(ABC):
         # summarize profile results
         self.profiler.describe()

+        self._teardown_already_run = True
+
     def training_forward(self, batch, batch_idx, opt_idx, hiddens):
         """
         Handle forward for each training case (distributed, single gpu, etc...)
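The flag is set only after the profiler summary, and the guard at the top of run_training_teardown checks it with hasattr, presumably because nothing initializes the attribute ahead of time. An equivalent, self-contained spelling of that first-call-safe guard (illustrative class, not the actual mixin):

class _TeardownOnce:
    def teardown(self):
        # getattr with a default plays the same role as the hasattr(...) and ... check:
        # the first call is safe even though no __init__ ever set the flag
        if getattr(self, '_teardown_already_run', False):
            return
        print('tearing down')
        self._teardown_already_run = True


obj = _TeardownOnce()
obj.teardown()  # runs the body
obj.teardown()  # skipped by the guard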