Graceful shutdown on Python interpreter exit (#1631)

* Graceful shutdown on Python interpreter exit

* Update CHANGELOG.md

* Update training_loop.py

* Update training_loop.py

* Update CHANGELOG.md

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* pep8, move to constant

* Update training_loop.py

* Update training_loop.py

* Update training_loop.py

* pep8, move to constant

* pep8

* timeout

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Justus Schock 2020-05-29 16:20:04 +02:00 committed by GitHub
parent 3af3f37d43
commit ceecf1cea9
3 changed files with 29 additions and 5 deletions
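The change follows a common graceful-shutdown pattern: install handlers for fatal signals that delegate to the trainer teardown, restore the original handlers once training finishes normally, and register the teardown with atexit so it also runs at interpreter exit. Below is a minimal standalone sketch of the signal-handling half, using only the standard library; the SIGNAL_TERMINATE tuple mirrors the diff, while teardown, install_handlers and restore_handlers are illustrative names, not part of the PR.

import signal

# signals that should trigger a graceful shutdown (mirrors SIGNAL_TERMINATE in the diff)
SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')


def teardown(*args):
    # stand-in for the real cleanup (callbacks, logger finalization, ...)
    print('running graceful teardown')


def install_handlers():
    # signal.signal returns the previous handler, so it can be remembered
    originals = {}
    for sig_name in SIGNAL_TERMINATE:
        originals[sig_name] = signal.signal(getattr(signal, sig_name), teardown)
    return originals


def restore_handlers(originals):
    # put the previous handlers back once the critical section is over
    for sig_name, handler in originals.items():
        signal.signal(getattr(signal, sig_name), handler)


if __name__ == '__main__':
    saved = install_handlers()
    try:
        pass  # the training loop would run here
    finally:
        restore_handlers(saved)

Saving the return value of signal.signal is what makes the later restore possible; the handlers are only swapped in for the duration of the training loop.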

.circleci/config.yml

@@ -22,7 +22,7 @@ references:
       command: |
         python --version ; pip --version ; pip list
         py.test pytorch_lightning tests -v --doctest-modules --junitxml=test-reports/pytest_junit.xml
-      no_output_timeout: 30m
+      no_output_timeout: 15m
 
   examples: &examples
     run:

CHANGELOG.md

@@ -36,6 +36,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
+- Run graceful training teardown on interpreter exit ([#1631](https://github.com/PyTorchLightning/pytorch-lightning/pull/1631))
+
 - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
 - Fixed multiple calls of `EarlyStopping` callback ([#1751](https://github.com/PyTorchLightning/pytorch-lightning/issues/1751))
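The training_loop.py diff that follows also registers run_training_teardown with atexit and guards it with a _teardown_already_run flag, so the teardown executes at most once whether it is reached via a signal handler, normal loop completion, or interpreter exit. A minimal sketch of that idempotent-teardown idea, assuming a hypothetical Trainer class; unlike the diff, which decorates the method with @atexit.register, this sketch registers the bound method per instance, so self is supplied when the hook fires.

import atexit


class Trainer:
    def __init__(self):
        self._teardown_already_run = False
        # run the teardown at interpreter exit as a safety net
        atexit.register(self.run_training_teardown)

    def run_training_teardown(self):
        # idempotent: signals, normal completion and interpreter exit
        # may all reach this method, but the body runs only once
        if self._teardown_already_run:
            return
        self._teardown_already_run = True
        print('on_train_end callbacks, logger finalization, profiler summary ...')


if __name__ == '__main__':
    trainer = Trainer()
    trainer.run_training_teardown()  # explicit call runs the teardown
    # the atexit hook fires afterwards, but the flag makes it a no-op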

pytorch_lightning/trainer/training_loop.py

@@ -141,21 +141,23 @@ in your model.
 """
 
+import atexit
+import signal
 from abc import ABC, abstractmethod
 from typing import Callable
 from typing import Union, List
 
 import numpy as np
-from torch.utils.data import DataLoader
 import torch
+from torch.utils.data import DataLoader
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 try:
     from apex import amp
@@ -179,9 +181,11 @@ except ImportError:
 else:
     HOROVOD_AVAILABLE = True
 
+# signals which should be caught for a graceful trainer shutdown
+SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')
 
 class TrainerTrainLoopMixin(ABC):
 
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     max_epochs: int
@@ -300,6 +304,15 @@ class TrainerTrainLoopMixin(ABC):
         """Warning: this is just empty shell for code implemented in other class."""
 
     def train(self):
+        # add signal handlers for process kills
+        def _signal_kill_handler(*args):
+            return TrainerTrainLoopMixin.run_training_teardown(self)
+
+        orig_signal_handlers = {}
+        for sig_name in SIGNAL_TERMINATE:
+            orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name),
+                                                           _signal_kill_handler)
+
         # get model
         model = self.get_model()
@@ -371,6 +384,10 @@ class TrainerTrainLoopMixin(ABC):
             self.run_training_teardown()
 
+            # reset signal handlers
+            for sig_name in SIGNAL_TERMINATE:
+                signal.signal(getattr(signal, sig_name), orig_signal_handlers[sig_name])
+
         except KeyboardInterrupt:
             if self.proc_rank == 0:
                 log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
@@ -405,7 +422,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # run epoch
         for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
-            enumerate(_with_is_last(train_dataloader)), "get_train_batch"
+                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
         ):
 
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:
@@ -661,7 +678,10 @@ class TrainerTrainLoopMixin(ABC):
             opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
             return [(opt_idx, self.optimizers[opt_idx])]
 
+    @atexit.register
     def run_training_teardown(self):
+        if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
+            return
 
         # Train end events
         with self.profiler.profile('on_train_end'):
             # callbacks
@@ -676,6 +696,8 @@ class TrainerTrainLoopMixin(ABC):
         # summarize profile results
         self.profiler.describe()
 
+        self._teardown_already_run = True
+
     def training_forward(self, batch, batch_idx, opt_idx, hiddens):
         """
         Handle forward for each training case (distributed, single gpu, etc...)