From c1af4d05279af7a4630d2a27b57bd85699797465 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 16 Jun 2024 16:43:42 +0200 Subject: [PATCH] Better graceful shutdown for KeyboardInterrupt (#19976) --- src/lightning/fabric/CHANGELOG.md | 34 ++++++++++++++++++ src/lightning/fabric/utilities/distributed.py | 4 +++ src/lightning/pytorch/CHANGELOG.md | 35 +++++++++++++++++++ .../strategies/launchers/multiprocessing.py | 2 +- .../strategies/launchers/subprocess_script.py | 2 +- src/lightning/pytorch/trainer/call.py | 20 +++++++---- .../trainer/connectors/signal_connector.py | 11 +++--- .../progress/test_rich_progress_bar.py | 2 +- .../callbacks/test_lambda_function.py | 9 +++-- tests/tests_pytorch/trainer/test_states.py | 3 +- tests/tests_pytorch/trainer/test_trainer.py | 34 +++++++++++++++--- 11 files changed, 134 insertions(+), 22 deletions(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 102783ea81..37322981c5 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -5,6 +5,40 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased] - YYYY-MM-DD + +### Added + +- + +- + +### Changed + +- + +- + +### Deprecated + +- + +- + +### Removed + +- + +- + +### Fixed + +- + +- + + + ## [2.3.0] - 2024-06-13 ### Added diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index bb20b889ec..75b2f7c580 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -2,6 +2,7 @@ import atexit import contextlib import logging import os +import signal import time from contextlib import nullcontext from datetime import timedelta @@ -306,8 +307,11 @@ def _init_dist_connection( def _destroy_dist_connection() -> None: + # Don't allow Ctrl+C to interrupt this handler + signal.signal(signal.SIGINT, signal.SIG_IGN) if _distributed_is_initialized(): torch.distributed.destroy_process_group() + signal.signal(signal.SIGINT, signal.SIG_DFL) def _get_default_process_group_backend_for_device(device: torch.device) -> str: diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 34ef2aa421..08562a9eb8 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,6 +4,41 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). + +## [unreleased] - YYYY-MM-DD + +### Added + +- + +- + +### Changed + +- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976)) + +- + +### Deprecated + +- + +- + +### Removed + +- + +- + +### Fixed + +- + +- + + + ## [2.3.0] - 2024-06-13 ### Added diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index aa96da63ad..58d9f2b16e 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -259,7 +259,7 @@ class _MultiProcessingLauncher(_Launcher): def kill(self, signum: _SIGNUM) -> None: for proc in self.procs: if proc.is_alive() and proc.pid is not None: - log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}") + log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}") with suppress(ProcessLookupError): os.kill(proc.pid, signum) diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py index 03dbbc5236..d2035d03d2 100644 --- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py +++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py @@ -107,7 +107,7 @@ class _SubprocessScriptLauncher(_Launcher): @override def kill(self, signum: _SIGNUM) -> None: for proc in self.procs: - log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}") + log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}") # this skips subprocesses already terminated proc.send_signal(signum) diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index befd7f0df8..4c3bc5ef41 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import signal from copy import deepcopy from typing import Any, Callable, Dict, Optional, Type, Union @@ -20,10 +21,12 @@ from packaging.version import Version import lightning.pytorch as pl from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.pytorch.callbacks import Checkpoint, EarlyStopping +from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher +from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal from lightning.pytorch.trainer.states import TrainerStatus from lightning.pytorch.utilities.exceptions import _TunerExitException from lightning.pytorch.utilities.model_helpers import is_overridden -from lightning.pytorch.utilities.rank_zero import rank_zero_warn +from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn log = logging.getLogger(__name__) @@ -49,12 +52,17 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg trainer.state.status = TrainerStatus.FINISHED trainer.state.stage = None - # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise except KeyboardInterrupt as exception: - rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") - # user could press Ctrl+c many times... only shutdown once - if not trainer.interrupted: - _interrupt(trainer, exception) + rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...") + # user could press Ctrl+C many times, disable KeyboardInterrupt for shutdown + signal.signal(signal.SIGINT, signal.SIG_IGN) + _interrupt(trainer, exception) + trainer._teardown() + launcher = trainer.strategy.launcher + if isinstance(launcher, _SubprocessScriptLauncher): + launcher.kill(_get_sigkill_signal()) + exit(1) + except BaseException as exception: _interrupt(trainer, exception) trainer._teardown() diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index 728d8b6b6e..ca9e3eb249 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -2,7 +2,6 @@ import logging import os import re import signal -import sys import threading from subprocess import call from types import FrameType @@ -54,7 +53,7 @@ class _SignalConnector: sigterm_handlers.append(self._sigterm_handler_fn) # Windows seems to have signal incompatibilities - if not self._is_on_windows(): + if not _IS_WINDOWS: sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1 assert sigusr is not None if sigusr_handlers and not self._has_already_handler(sigusr): @@ -155,10 +154,6 @@ class _SignalConnector: } return set(signal.Signals) - @staticmethod - def _is_on_windows() -> bool: - return sys.platform == "win32" - @staticmethod def _has_already_handler(signum: _SIGNUM) -> bool: return signal.getsignal(signum) not in (None, signal.SIG_DFL) @@ -172,3 +167,7 @@ class _SignalConnector: state = self.__dict__.copy() state["_original_handlers"] = {} return state + + +def _get_sigkill_signal() -> _SIGNUM: + return signal.SIGTERM if _IS_WINDOWS else signal.SIGKILL diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 22e83443ef..de41035d4d 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -143,7 +143,7 @@ def test_rich_progress_bar_keyboard_interrupt(tmp_path): with mock.patch( "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True - ) as mock_progress_stop: + ) as mock_progress_stop, pytest.raises(SystemExit): progress_bar = RichProgressBar() trainer = Trainer( default_root_dir=tmp_path, diff --git a/tests/tests_pytorch/callbacks/test_lambda_function.py b/tests/tests_pytorch/callbacks/test_lambda_function.py index 483c8c73e9..40d694bb35 100644 --- a/tests/tests_pytorch/callbacks/test_lambda_function.py +++ b/tests/tests_pytorch/callbacks/test_lambda_function.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +import pytest from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import Callback, LambdaCallback from lightning.pytorch.demos.boring_classes import BoringModel @@ -23,10 +24,13 @@ from tests_pytorch.models.test_hooks import get_members def test_lambda_call(tmp_path): seed_everything(42) + class CustomException(Exception): + pass + class CustomModel(BoringModel): def on_train_epoch_start(self): if self.current_epoch > 1: - raise KeyboardInterrupt + raise CustomException("Custom exception to trigger `on_exception` hooks") checker = set() @@ -59,7 +63,8 @@ def test_lambda_call(tmp_path): limit_predict_batches=1, callbacks=[LambdaCallback(**hooks_args)], ) - trainer.fit(model, ckpt_path=ckpt_path) + with pytest.raises(CustomException): + trainer.fit(model, ckpt_path=ckpt_path) trainer.test(model) trainer.predict(model) diff --git a/tests/tests_pytorch/trainer/test_states.py b/tests/tests_pytorch/trainer/test_states.py index bd5fd1c67e..d89e99c931 100644 --- a/tests/tests_pytorch/trainer/test_states.py +++ b/tests/tests_pytorch/trainer/test_states.py @@ -84,5 +84,6 @@ def test_interrupt_state_on_keyboard_interrupt(tmp_path, extra_params): trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmp_path, **extra_params) - trainer.fit(model) + with pytest.raises(SystemExit): + trainer.fit(model) assert trainer.interrupted diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 1791f498d5..802c1a17bc 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -28,6 +28,7 @@ import pytest import torch import torch.nn as nn from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.fabric.utilities.seed import seed_everything from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator @@ -45,7 +46,7 @@ from lightning.pytorch.demos.boring_classes import ( from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy -from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher +from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE @@ -1007,7 +1008,8 @@ def test_on_exception_hook(tmp_path): ) assert not trainer.interrupted assert handle_interrupt_callback.exception is None - trainer.fit(model) + with pytest.raises(SystemExit): + trainer.fit(model) assert trainer.interrupted assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt) with pytest.raises(MisconfigurationException): @@ -1016,6 +1018,30 @@ def test_on_exception_hook(tmp_path): assert isinstance(handle_interrupt_callback.exception, MisconfigurationException) +def test_keyboard_interrupt(tmp_path): + class InterruptCallback(Callback): + def __init__(self): + super().__init__() + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + raise KeyboardInterrupt + + model = BoringModel() + trainer = Trainer( + callbacks=[InterruptCallback()], + barebones=True, + default_root_dir=tmp_path, + ) + + trainer.strategy._launcher = Mock(spec=_SubprocessScriptLauncher) + trainer.strategy._launcher.launch = lambda function, *args, trainer, **kwargs: function(*args, **kwargs) + + with pytest.raises(SystemExit) as exc_info: + trainer.fit(model) + assert exc_info.value.args[0] == 1 + trainer.strategy._launcher.kill.assert_called_once_with(15 if _IS_WINDOWS else 9) + + @pytest.mark.parametrize("precision", ["32-true", pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1))]) @RunIf(sklearn=True) def test_gradient_clipping_by_norm(tmp_path, precision): @@ -2042,7 +2068,7 @@ def test_trainer_calls_strategy_on_exception(exception_type, tmp_path): trainer = Trainer(default_root_dir=tmp_path) with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress( - Exception + Exception, SystemExit ): trainer.fit(ExceptionModel()) on_exception_mock.assert_called_once_with(exception) @@ -2061,7 +2087,7 @@ def test_trainer_calls_datamodule_on_exception(exception_type, tmp_path): datamodule.on_exception = Mock() trainer = Trainer(default_root_dir=tmp_path) - with suppress(Exception): + with suppress(Exception, SystemExit): trainer.fit(ExceptionModel(), datamodule=datamodule) datamodule.on_exception.assert_called_once_with(exception)