Better graceful shutdown for KeyboardInterrupt (#19976)
parent b16e998a6e
commit c1af4d0527
@@ -5,6 +5,40 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+-
+
+-
+
+### Changed
+
+-
+
+-
+
+### Deprecated
+
+-
+
+-
+
+### Removed
+
+-
+
+-
+
+### Fixed
+
+-
+
+-
+
+
 ## [2.3.0] - 2024-06-13

 ### Added

@@ -2,6 +2,7 @@ import atexit
 import contextlib
 import logging
 import os
+import signal
 import time
 from contextlib import nullcontext
 from datetime import timedelta

@@ -306,8 +307,11 @@ def _init_dist_connection(


 def _destroy_dist_connection() -> None:
+    # Don't allow Ctrl+C to interrupt this handler
+    signal.signal(signal.SIGINT, signal.SIG_IGN)
     if _distributed_is_initialized():
         torch.distributed.destroy_process_group()
+    signal.signal(signal.SIGINT, signal.SIG_DFL)


 def _get_default_process_group_backend_for_device(device: torch.device) -> str:

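A note on the hunk above: the two `signal.signal` calls bracket the collective teardown so that a second Ctrl+C cannot interrupt `destroy_process_group()` midway. Written as a reusable context manager, the same idea might look like the sketch below; this is illustrative only (it restores the previous handler rather than `SIG_DFL`), and like any `signal.signal` call it must run in the main thread.

```python
import signal
from contextlib import contextmanager


@contextmanager
def sigint_ignored():
    """Ignore Ctrl+C (SIGINT) for the duration of a critical cleanup block."""
    previous = signal.signal(signal.SIGINT, signal.SIG_IGN)  # start ignoring Ctrl+C
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, previous)  # hand the old handler back
```

Usage would be `with sigint_ignored(): torch.distributed.destroy_process_group()`, matching the intent of the change above.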
@@ -4,6 +4,41 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+-
+
+-
+
+### Changed
+
+- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))
+
+-
+
+### Deprecated
+
+-
+
+-
+
+### Removed
+
+-
+
+-
+
+### Fixed
+
+-
+
+-
+
+
 ## [2.3.0] - 2024-06-13

 ### Added

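In user terms, the entry above means an interrupt raised inside any Trainer entry point now ends the whole program (exit code 1) instead of returning control to the calling script. A hypothetical script illustrating the new behaviour:

```python
# Hypothetical user script: with this change, pressing Ctrl+C while fit() is
# running tears down the Trainer's worker processes and raises SystemExit(1),
# so the final print is no longer reached.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
    trainer = Trainer(max_epochs=1000)
    trainer.fit(BoringModel())
    print("not reached if training was interrupted with Ctrl+C")
```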
@@ -259,7 +259,7 @@ class _MultiProcessingLauncher(_Launcher):
     def kill(self, signum: _SIGNUM) -> None:
         for proc in self.procs:
             if proc.is_alive() and proc.pid is not None:
-                log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
+                log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
                 with suppress(ProcessLookupError):
                     os.kill(proc.pid, signum)

@@ -107,7 +107,7 @@ class _SubprocessScriptLauncher(_Launcher):
     @override
     def kill(self, signum: _SIGNUM) -> None:
         for proc in self.procs:
-            log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
+            log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
             # this skips subprocesses already terminated
             proc.send_signal(signum)

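The comment kept in the hunk above relies on documented `Popen.send_signal` behaviour: the call does nothing once the child is known to have exited, which is why no `ProcessLookupError` handling is needed here, unlike the `os.kill` path in the multiprocessing launcher. A small POSIX-only illustration:

```python
import signal
import subprocess

# Start a child that exits immediately, wait for it, then signal it.
proc = subprocess.Popen(["true"])
proc.wait()
proc.send_signal(signal.SIGTERM)  # no-op: the child has already completed
```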
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import signal
 from copy import deepcopy
 from typing import Any, Callable, Dict, Optional, Type, Union

@@ -20,10 +21,12 @@ from packaging.version import Version

 import lightning.pytorch as pl
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
+from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
+from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
 from lightning.pytorch.trainer.states import TrainerStatus
 from lightning.pytorch.utilities.exceptions import _TunerExitException
 from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import rank_zero_warn
+from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

 log = logging.getLogger(__name__)

@@ -49,12 +52,17 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
         trainer.state.status = TrainerStatus.FINISHED
         trainer.state.stage = None

     # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
     except KeyboardInterrupt as exception:
-        rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
-        # user could press Ctrl+c many times... only shutdown once
-        if not trainer.interrupted:
-            _interrupt(trainer, exception)
+        rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...")
+        # user could press Ctrl+C many times, disable KeyboardInterrupt for shutdown
+        signal.signal(signal.SIGINT, signal.SIG_IGN)
+        _interrupt(trainer, exception)
+        trainer._teardown()
+        launcher = trainer.strategy.launcher
+        if isinstance(launcher, _SubprocessScriptLauncher):
+            launcher.kill(_get_sigkill_signal())
+        exit(1)
     except BaseException as exception:
         _interrupt(trainer, exception)
         trainer._teardown()

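The rewritten handler above follows a common shutdown sequence for multi-process training: log once, mask further SIGINTs so repeated Ctrl+C presses cannot re-enter the handler, run teardown, force-kill any child processes the launcher spawned, then exit non-zero. A minimal standalone sketch of that sequence, using hypothetical `work` and `kill_children` callables rather than the Lightning internals:

```python
import os
import signal
import sys


def run_with_graceful_interrupt(work, kill_children):
    """Run `work()`; on Ctrl+C shut down once, reap children, exit with status 1."""
    try:
        return work()
    except KeyboardInterrupt:
        signal.signal(signal.SIGINT, signal.SIG_IGN)  # further Ctrl+C is ignored
        kill_children(signal.SIGTERM if os.name == "nt" else signal.SIGKILL)
        sys.exit(1)
```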
@@ -2,7 +2,6 @@ import logging
 import os
 import re
 import signal
-import sys
 import threading
 from subprocess import call
 from types import FrameType

@@ -54,7 +53,7 @@ class _SignalConnector:
             sigterm_handlers.append(self._sigterm_handler_fn)

         # Windows seems to have signal incompatibilities
-        if not self._is_on_windows():
+        if not _IS_WINDOWS:
             sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
             assert sigusr is not None
             if sigusr_handlers and not self._has_already_handler(sigusr):

@@ -155,10 +154,6 @@ class _SignalConnector:
             }
         return set(signal.Signals)

-    @staticmethod
-    def _is_on_windows() -> bool:
-        return sys.platform == "win32"
-
     @staticmethod
     def _has_already_handler(signum: _SIGNUM) -> bool:
         return signal.getsignal(signum) not in (None, signal.SIG_DFL)

@@ -172,3 +167,7 @@
         state = self.__dict__.copy()
         state["_original_handlers"] = {}
         return state
+
+
+def _get_sigkill_signal() -> _SIGNUM:
+    return signal.SIGTERM if _IS_WINDOWS else signal.SIGKILL

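`_get_sigkill_signal` falls back to `SIGTERM` because Windows has no `SIGKILL`; on POSIX the hard kill is preferred since the children are already past the point of doing useful cleanup. A standalone, hedged equivalent without the Lightning type alias (hypothetical helper name):

```python
import signal
import sys


def pick_kill_signal() -> signal.Signals:
    """Mirror the idea above: use the hard kill where the platform provides one."""
    return signal.SIGTERM if sys.platform == "win32" else signal.SIGKILL
```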
@@ -143,7 +143,7 @@ def test_rich_progress_bar_keyboard_interrupt(tmp_path):

     with mock.patch(
         "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
-    ) as mock_progress_stop:
+    ) as mock_progress_stop, pytest.raises(SystemExit):
         progress_bar = RichProgressBar()
         trainer = Trainer(
             default_root_dir=tmp_path,

@@ -13,6 +13,7 @@
 # limitations under the License.
 from functools import partial

+import pytest
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import Callback, LambdaCallback
 from lightning.pytorch.demos.boring_classes import BoringModel

@@ -23,10 +24,13 @@ from tests_pytorch.models.test_hooks import get_members
 def test_lambda_call(tmp_path):
     seed_everything(42)

+    class CustomException(Exception):
+        pass
+
     class CustomModel(BoringModel):
         def on_train_epoch_start(self):
             if self.current_epoch > 1:
-                raise KeyboardInterrupt
+                raise CustomException("Custom exception to trigger `on_exception` hooks")

     checker = set()

@@ -59,7 +63,8 @@ def test_lambda_call(tmp_path):
         limit_predict_batches=1,
         callbacks=[LambdaCallback(**hooks_args)],
     )
-    trainer.fit(model, ckpt_path=ckpt_path)
+    with pytest.raises(CustomException):
+        trainer.fit(model, ckpt_path=ckpt_path)
     trainer.test(model)
     trainer.predict(model)

@@ -84,5 +84,6 @@ def test_interrupt_state_on_keyboard_interrupt(tmp_path, extra_params):

     trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmp_path, **extra_params)

-    trainer.fit(model)
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
     assert trainer.interrupted

@@ -28,6 +28,7 @@ import pytest
 import torch
 import torch.nn as nn
 from lightning.fabric.utilities.cloud_io import _load as pl_load
+from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.fabric.utilities.seed import seed_everything
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator

@@ -45,7 +46,7 @@ from lightning.pytorch.demos.boring_classes import (
 from lightning.pytorch.loggers import TensorBoardLogger
 from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper
 from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy
-from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
+from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE

@@ -1007,7 +1008,8 @@ def test_on_exception_hook(tmp_path):
     )
     assert not trainer.interrupted
     assert handle_interrupt_callback.exception is None
-    trainer.fit(model)
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
     assert trainer.interrupted
     assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt)
     with pytest.raises(MisconfigurationException):

@@ -1016,6 +1018,30 @@
     assert isinstance(handle_interrupt_callback.exception, MisconfigurationException)


+def test_keyboard_interrupt(tmp_path):
+    class InterruptCallback(Callback):
+        def __init__(self):
+            super().__init__()
+
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+            raise KeyboardInterrupt
+
+    model = BoringModel()
+    trainer = Trainer(
+        callbacks=[InterruptCallback()],
+        barebones=True,
+        default_root_dir=tmp_path,
+    )
+
+    trainer.strategy._launcher = Mock(spec=_SubprocessScriptLauncher)
+    trainer.strategy._launcher.launch = lambda function, *args, trainer, **kwargs: function(*args, **kwargs)
+
+    with pytest.raises(SystemExit) as exc_info:
+        trainer.fit(model)
+    assert exc_info.value.args[0] == 1
+    trainer.strategy._launcher.kill.assert_called_once_with(15 if _IS_WINDOWS else 9)
+
+
 @pytest.mark.parametrize("precision", ["32-true", pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1))])
 @RunIf(sklearn=True)
 def test_gradient_clipping_by_norm(tmp_path, precision):

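The final assertion in the new test compares against raw signal numbers: on POSIX platforms `SIGKILL` is 9 and `SIGTERM` is 15, which is exactly what `_get_sigkill_signal()` returns on each platform. A quick sanity check (illustrative, not part of the test suite):

```python
import signal

# On Linux/macOS this prints 9 and 15, the values asserted in the test above;
# Windows only defines SIGTERM, hence the 15 branch there.
print(int(signal.SIGKILL), int(signal.SIGTERM))
```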
@@ -2042,7 +2068,7 @@ def test_trainer_calls_strategy_on_exception(exception_type, tmp_path):

     trainer = Trainer(default_root_dir=tmp_path)
     with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress(
-        Exception
+        Exception, SystemExit
     ):
         trainer.fit(ExceptionModel())
     on_exception_mock.assert_called_once_with(exception)

@@ -2061,7 +2087,7 @@ def test_trainer_calls_datamodule_on_exception(exception_type, tmp_path):
     datamodule.on_exception = Mock()
     trainer = Trainer(default_root_dir=tmp_path)

-    with suppress(Exception):
+    with suppress(Exception, SystemExit):
         trainer.fit(ExceptionModel(), datamodule=datamodule)
     datamodule.on_exception.assert_called_once_with(exception)