Add `on_exception` to DataModule (#19601)
Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
This commit is contained in:
parent
6cfc590716
commit
d5a9b775ce
|
@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added robust timer duration parsing with an informative error message when parsing fails ([#19513](https://github.com/Lightning-AI/pytorch-lightning/pull/19513))
|
||||
|
||||
- Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601))
|
||||
|
||||
-
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -62,6 +62,10 @@ class LightningDataModule(DataHooks, HyperparametersMixin):
|
|||
def test_dataloader(self):
|
||||
return data.DataLoader(self.test)
|
||||
|
||||
def on_exception(self, exception):
|
||||
# clean up state after the trainer encountered an exception
|
||||
...
|
||||
|
||||
def teardown(self):
|
||||
# clean up state after the trainer stops, delete files...
|
||||
# called on every process in DDP
|
||||
|
@ -161,6 +165,10 @@ class LightningDataModule(DataHooks, HyperparametersMixin):
|
|||
"""
|
||||
pass
|
||||
|
||||
def on_exception(self, exception: BaseException) -> None:
    """Called when the trainer execution is interrupted by an exception.

    The default implementation is a no-op; override it to release or clean up
    any state held by the data module.

    Args:
        exception: The exception that interrupted the trainer.

    """
|
||||
|
||||
@_restricted_classmethod
|
||||
def load_from_checkpoint(
|
||||
cls,
|
||||
|
|
|
@ -54,23 +54,25 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
|
|||
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
|
||||
# user could press Ctrl+c many times... only shutdown once
|
||||
if not trainer.interrupted:
|
||||
trainer.state.status = TrainerStatus.INTERRUPTED
|
||||
_call_callback_hooks(trainer, "on_exception", exception)
|
||||
trainer.strategy.on_exception(exception)
|
||||
for logger in trainer.loggers:
|
||||
logger.finalize("failed")
|
||||
_interrupt(trainer, exception)
|
||||
except BaseException as exception:
|
||||
trainer.state.status = TrainerStatus.INTERRUPTED
|
||||
_call_callback_hooks(trainer, "on_exception", exception)
|
||||
trainer.strategy.on_exception(exception)
|
||||
for logger in trainer.loggers:
|
||||
logger.finalize("failed")
|
||||
_interrupt(trainer, exception)
|
||||
trainer._teardown()
|
||||
# teardown might access the stage so we reset it after
|
||||
trainer.state.stage = None
|
||||
raise
|
||||
|
||||
|
||||
def _interrupt(trainer: "pl.Trainer", exception: BaseException) -> None:
    """Flag the trainer as interrupted and hand ``exception`` to its components.

    In order: the trainer status is set to ``INTERRUPTED``, the ``on_exception``
    callback hooks are invoked, the data module's ``on_exception`` hook is
    invoked when a data module is attached, the strategy is notified, and
    finally every logger is finalized with the ``"failed"`` status.

    Args:
        trainer: The trainer whose run was interrupted.
        exception: The exception that caused the interruption.

    """
    hook_name = "on_exception"

    trainer.state.status = TrainerStatus.INTERRUPTED
    _call_callback_hooks(trainer, hook_name, exception)

    # the data module is optional, only forward the exception when one is attached
    datamodule_attached = trainer.datamodule is not None
    if datamodule_attached:
        _call_lightning_datamodule_hook(trainer, hook_name, exception)

    trainer.strategy.on_exception(exception)

    for active_logger in trainer.loggers:
        active_logger.finalize("failed")
|
||||
|
||||
|
||||
def _call_setup_hook(trainer: "pl.Trainer") -> None:
|
||||
assert trainer.state.fn is not None
|
||||
fn = trainer.state.fn
|
||||
|
|
|
@ -38,6 +38,7 @@ from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheck
|
|||
from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
|
||||
from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
|
||||
from lightning.pytorch.demos.boring_classes import (
|
||||
BoringDataModule,
|
||||
BoringModel,
|
||||
RandomDataset,
|
||||
RandomIterableDataset,
|
||||
|
@ -2050,6 +2051,24 @@ def test_trainer_calls_strategy_on_exception(exception_type):
|
|||
on_exception_mock.assert_called_once_with(exception)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
def test_trainer_calls_datamodule_on_exception(exception_type):
    """Test that when an exception occurs, the Trainer lets the data module process it."""
    raised = exception_type("Test exception")

    class FailingModel(BoringModel):
        # raise as soon as fitting starts so the trainer has to handle the exception
        def on_fit_start(self):
            raise raised

    dm = BoringDataModule()
    # replace the hook with a mock so the call and its argument can be inspected
    dm.on_exception = Mock()
    trainer = Trainer()

    with suppress(Exception):
        trainer.fit(FailingModel(), datamodule=dm)

    dm.on_exception.assert_called_once_with(raised)
|
||||
|
||||
|
||||
def test_init_module_context(monkeypatch):
|
||||
"""Test that the strategy returns the context manager for initializing the module."""
|
||||
trainer = Trainer(accelerator="cpu", devices=1)
|
||||
|
|
Loading…
Reference in New Issue