diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 4b175d0279..f400e21536 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added robust timer duration parsing with an informative error message when parsing fails ([#19513](https://github.com/Lightning-AI/pytorch-lightning/pull/19513)) +- Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601)) + - ### Changed diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 9c3518b5db..982fe46412 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -62,6 +62,10 @@ class LightningDataModule(DataHooks, HyperparametersMixin): def test_dataloader(self): return data.DataLoader(self.test) + def on_exception(self, exception): + # clean up state after the trainer faced an exception + ... + def teardown(self): # clean up state after the trainer stops, delete files... # called on every process in DDP @@ -161,6 +165,10 @@ class LightningDataModule(DataHooks, HyperparametersMixin): """ pass + def on_exception(self, exception: BaseException) -> None: + """Called when the trainer execution is interrupted by an exception.""" + pass + @_restricted_classmethod def load_from_checkpoint( cls, diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index b9c270b620..befd7f0df8 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -54,23 +54,25 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once if not trainer.interrupted: - trainer.state.status = TrainerStatus.INTERRUPTED - _call_callback_hooks(trainer, "on_exception", exception) - trainer.strategy.on_exception(exception) - for logger in trainer.loggers: - logger.finalize("failed") + _interrupt(trainer, exception) except BaseException as exception: - trainer.state.status = TrainerStatus.INTERRUPTED - _call_callback_hooks(trainer, "on_exception", exception) - trainer.strategy.on_exception(exception) - for logger in trainer.loggers: - logger.finalize("failed") + _interrupt(trainer, exception) trainer._teardown() # teardown might access the stage so we reset it after trainer.state.stage = None raise +def _interrupt(trainer: "pl.Trainer", exception: BaseException) -> None: + trainer.state.status = TrainerStatus.INTERRUPTED + _call_callback_hooks(trainer, "on_exception", exception) + if trainer.datamodule is not None: + _call_lightning_datamodule_hook(trainer, "on_exception", exception) + trainer.strategy.on_exception(exception) + for logger in trainer.loggers: + logger.finalize("failed") + + def _call_setup_hook(trainer: "pl.Trainer") -> None: assert trainer.state.fn is not None fn = trainer.state.fn diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 612aec11e6..565971e155 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -38,6 +38,7 @@ from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheck from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv from lightning.pytorch.demos.boring_classes import ( + BoringDataModule, BoringModel, RandomDataset, RandomIterableDataset, @@ -2050,6 +2051,24 @@ def test_trainer_calls_strategy_on_exception(exception_type): on_exception_mock.assert_called_once_with(exception) +@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError]) +def test_trainer_calls_datamodule_on_exception(exception_type): + """Test that when an exception occurs, the Trainer lets the data module process it.""" + exception = exception_type("Test exception") + + class ExceptionModel(BoringModel): + def on_fit_start(self): + raise exception + + datamodule = BoringDataModule() + datamodule.on_exception = Mock() + trainer = Trainer() + + with suppress(Exception): + trainer.fit(ExceptionModel(), datamodule=datamodule) + datamodule.on_exception.assert_called_once_with(exception) + + def test_init_module_context(monkeypatch): """Test that the strategy returns the context manager for initializing the module.""" trainer = Trainer(accelerator="cpu", devices=1)