Add `on_exception` to DataModule (#19601)

Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
This commit is contained in:
Alexander Jipa 2024-03-22 15:56:12 -04:00 committed by GitHub
parent 6cfc590716
commit d5a9b775ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 10 deletions

View File

@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added robust timer duration parsing with an informative error message when parsing fails ([#19513](https://github.com/Lightning-AI/pytorch-lightning/pull/19513))
- Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601))
-
### Changed

View File

@ -62,6 +62,10 @@ class LightningDataModule(DataHooks, HyperparametersMixin):
def test_dataloader(self):
return data.DataLoader(self.test)
def on_exception(self, exception):
# clean up state after the trainer encountered an exception
...
def teardown(self):
# clean up state after the trainer stops, delete files...
# called on every process in DDP
@ -161,6 +165,10 @@ class LightningDataModule(DataHooks, HyperparametersMixin):
"""
pass
def on_exception(self, exception: BaseException) -> None:
    """Hook called when the trainer execution is interrupted by an exception.

    The default implementation does nothing; override it to clean up any
    state the data module holds.

    Args:
        exception: The exception that interrupted the trainer.

    """
@_restricted_classmethod
def load_from_checkpoint(
cls,

View File

@ -54,23 +54,25 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
# user could press Ctrl+c many times... only shutdown once
if not trainer.interrupted:
trainer.state.status = TrainerStatus.INTERRUPTED
_call_callback_hooks(trainer, "on_exception", exception)
trainer.strategy.on_exception(exception)
for logger in trainer.loggers:
logger.finalize("failed")
_interrupt(trainer, exception)
except BaseException as exception:
trainer.state.status = TrainerStatus.INTERRUPTED
_call_callback_hooks(trainer, "on_exception", exception)
trainer.strategy.on_exception(exception)
for logger in trainer.loggers:
logger.finalize("failed")
_interrupt(trainer, exception)
trainer._teardown()
# teardown might access the stage so we reset it after
trainer.state.stage = None
raise
def _interrupt(trainer: "pl.Trainer", exception: BaseException) -> None:
    """Mark the trainer as interrupted and notify every component of ``exception``.

    The notification order is: callbacks, the attached datamodule (if any),
    the strategy, and finally each logger, which is finalized with the
    ``"failed"`` status.

    Args:
        trainer: The trainer whose run was interrupted.
        exception: The exception that caused the interruption.

    """
    trainer.state.status = TrainerStatus.INTERRUPTED

    hook_name = "on_exception"
    _call_callback_hooks(trainer, hook_name, exception)
    # Only dispatch to the datamodule hook when a datamodule is attached.
    if trainer.datamodule is not None:
        _call_lightning_datamodule_hook(trainer, hook_name, exception)
    trainer.strategy.on_exception(exception)

    for active_logger in trainer.loggers:
        active_logger.finalize("failed")
def _call_setup_hook(trainer: "pl.Trainer") -> None:
assert trainer.state.fn is not None
fn = trainer.state.fn

View File

@ -38,6 +38,7 @@ from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheck
from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from lightning.pytorch.demos.boring_classes import (
BoringDataModule,
BoringModel,
RandomDataset,
RandomIterableDataset,
@ -2050,6 +2051,24 @@ def test_trainer_calls_strategy_on_exception(exception_type):
on_exception_mock.assert_called_once_with(exception)
@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
def test_trainer_calls_datamodule_on_exception(exception_type):
    """Check that the Trainer hands a raised exception to the data module's ``on_exception`` hook."""
    raised = exception_type("Test exception")

    class ExceptionModel(BoringModel):
        # Fail as soon as fitting starts so the exception path is exercised.
        def on_fit_start(self):
            raise raised

    hook = Mock()
    dm = BoringDataModule()
    # Replace the hook with a mock so the call and its argument can be inspected.
    dm.on_exception = hook

    trainer = Trainer()
    # Any ``Exception`` escaping ``fit`` is irrelevant here; only the hook call is checked.
    with suppress(Exception):
        trainer.fit(ExceptionModel(), datamodule=dm)

    hook.assert_called_once_with(raised)
def test_init_module_context(monkeypatch):
"""Test that the strategy returns the context manager for initializing the module."""
trainer = Trainer(accelerator="cpu", devices=1)