Better graceful shutdown for KeyboardInterrupt (#19976)

This commit is contained in:
awaelchli 2024-06-16 16:43:42 +02:00 committed by GitHub
parent b16e998a6e
commit c1af4d0527
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 134 additions and 22 deletions

View File

@ -5,6 +5,40 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [unreleased] - YYYY-MM-DD
### Added
-
-
### Changed
-
-
### Deprecated
-
-
### Removed
-
-
### Fixed
-
-
## [2.3.0] - 2024-06-13
### Added

View File

@ -2,6 +2,7 @@ import atexit
import contextlib
import logging
import os
import signal
import time
from contextlib import nullcontext
from datetime import timedelta
@ -306,8 +307,11 @@ def _init_dist_connection(
def _destroy_dist_connection() -> None:
# Don't allow Ctrl+C to interrupt this handler
signal.signal(signal.SIGINT, signal.SIG_IGN)
if _distributed_is_initialized():
torch.distributed.destroy_process_group()
signal.signal(signal.SIGINT, signal.SIG_DFL)
def _get_default_process_group_backend_for_device(device: torch.device) -> str:

View File

@ -4,6 +4,41 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [unreleased] - YYYY-MM-DD
### Added
-
-
### Changed
- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))
-
### Deprecated
-
-
### Removed
-
-
### Fixed
-
-
## [2.3.0] - 2024-06-13
### Added

View File

@ -259,7 +259,7 @@ class _MultiProcessingLauncher(_Launcher):
def kill(self, signum: _SIGNUM) -> None:
for proc in self.procs:
if proc.is_alive() and proc.pid is not None:
log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
with suppress(ProcessLookupError):
os.kill(proc.pid, signum)

View File

@ -107,7 +107,7 @@ class _SubprocessScriptLauncher(_Launcher):
@override
def kill(self, signum: _SIGNUM) -> None:
for proc in self.procs:
log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
# this skips subprocesses already terminated
proc.send_signal(signum)

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Type, Union
@ -20,10 +21,12 @@ from packaging.version import Version
import lightning.pytorch as pl
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
from lightning.pytorch.trainer.states import TrainerStatus
from lightning.pytorch.utilities.exceptions import _TunerExitException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
log = logging.getLogger(__name__)
@ -49,12 +52,17 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
trainer.state.status = TrainerStatus.FINISHED
trainer.state.stage = None
# TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
except KeyboardInterrupt as exception:
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
# user could press Ctrl+c many times... only shutdown once
if not trainer.interrupted:
_interrupt(trainer, exception)
rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...")
# user could press Ctrl+C many times, disable KeyboardInterrupt for shutdown
signal.signal(signal.SIGINT, signal.SIG_IGN)
_interrupt(trainer, exception)
trainer._teardown()
launcher = trainer.strategy.launcher
if isinstance(launcher, _SubprocessScriptLauncher):
launcher.kill(_get_sigkill_signal())
exit(1)
except BaseException as exception:
_interrupt(trainer, exception)
trainer._teardown()

View File

@ -2,7 +2,6 @@ import logging
import os
import re
import signal
import sys
import threading
from subprocess import call
from types import FrameType
@ -54,7 +53,7 @@ class _SignalConnector:
sigterm_handlers.append(self._sigterm_handler_fn)
# Windows seems to have signal incompatibilities
if not self._is_on_windows():
if not _IS_WINDOWS:
sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
assert sigusr is not None
if sigusr_handlers and not self._has_already_handler(sigusr):
@ -155,10 +154,6 @@ class _SignalConnector:
}
return set(signal.Signals)
@staticmethod
def _is_on_windows() -> bool:
return sys.platform == "win32"
@staticmethod
def _has_already_handler(signum: _SIGNUM) -> bool:
return signal.getsignal(signum) not in (None, signal.SIG_DFL)
@ -172,3 +167,7 @@ class _SignalConnector:
state = self.__dict__.copy()
state["_original_handlers"] = {}
return state
def _get_sigkill_signal() -> _SIGNUM:
return signal.SIGTERM if _IS_WINDOWS else signal.SIGKILL

View File

@ -143,7 +143,7 @@ def test_rich_progress_bar_keyboard_interrupt(tmp_path):
with mock.patch(
"lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
) as mock_progress_stop:
) as mock_progress_stop, pytest.raises(SystemExit):
progress_bar = RichProgressBar()
trainer = Trainer(
default_root_dir=tmp_path,

View File

@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
import pytest
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, LambdaCallback
from lightning.pytorch.demos.boring_classes import BoringModel
@ -23,10 +24,13 @@ from tests_pytorch.models.test_hooks import get_members
def test_lambda_call(tmp_path):
seed_everything(42)
class CustomException(Exception):
pass
class CustomModel(BoringModel):
def on_train_epoch_start(self):
if self.current_epoch > 1:
raise KeyboardInterrupt
raise CustomException("Custom exception to trigger `on_exception` hooks")
checker = set()
@ -59,7 +63,8 @@ def test_lambda_call(tmp_path):
limit_predict_batches=1,
callbacks=[LambdaCallback(**hooks_args)],
)
trainer.fit(model, ckpt_path=ckpt_path)
with pytest.raises(CustomException):
trainer.fit(model, ckpt_path=ckpt_path)
trainer.test(model)
trainer.predict(model)

View File

@ -84,5 +84,6 @@ def test_interrupt_state_on_keyboard_interrupt(tmp_path, extra_params):
trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmp_path, **extra_params)
trainer.fit(model)
with pytest.raises(SystemExit):
trainer.fit(model)
assert trainer.interrupted

View File

@ -28,6 +28,7 @@ import pytest
import torch
import torch.nn as nn
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.seed import seed_everything
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator
@ -45,7 +46,7 @@ from lightning.pytorch.demos.boring_classes import (
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper
from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
@ -1007,7 +1008,8 @@ def test_on_exception_hook(tmp_path):
)
assert not trainer.interrupted
assert handle_interrupt_callback.exception is None
trainer.fit(model)
with pytest.raises(SystemExit):
trainer.fit(model)
assert trainer.interrupted
assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt)
with pytest.raises(MisconfigurationException):
@ -1016,6 +1018,30 @@ def test_on_exception_hook(tmp_path):
assert isinstance(handle_interrupt_callback.exception, MisconfigurationException)
def test_keyboard_interrupt(tmp_path):
class InterruptCallback(Callback):
def __init__(self):
super().__init__()
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
raise KeyboardInterrupt
model = BoringModel()
trainer = Trainer(
callbacks=[InterruptCallback()],
barebones=True,
default_root_dir=tmp_path,
)
trainer.strategy._launcher = Mock(spec=_SubprocessScriptLauncher)
trainer.strategy._launcher.launch = lambda function, *args, trainer, **kwargs: function(*args, **kwargs)
with pytest.raises(SystemExit) as exc_info:
trainer.fit(model)
assert exc_info.value.args[0] == 1
trainer.strategy._launcher.kill.assert_called_once_with(15 if _IS_WINDOWS else 9)
@pytest.mark.parametrize("precision", ["32-true", pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1))])
@RunIf(sklearn=True)
def test_gradient_clipping_by_norm(tmp_path, precision):
@ -2042,7 +2068,7 @@ def test_trainer_calls_strategy_on_exception(exception_type, tmp_path):
trainer = Trainer(default_root_dir=tmp_path)
with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress(
Exception
Exception, SystemExit
):
trainer.fit(ExceptionModel())
on_exception_mock.assert_called_once_with(exception)
@ -2061,7 +2087,7 @@ def test_trainer_calls_datamodule_on_exception(exception_type, tmp_path):
datamodule.on_exception = Mock()
trainer = Trainer(default_root_dir=tmp_path)
with suppress(Exception):
with suppress(Exception, SystemExit):
trainer.fit(ExceptionModel(), datamodule=datamodule)
datamodule.on_exception.assert_called_once_with(exception)