diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index dbc9d1ec4d..565971e155 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -38,6 +38,7 @@ from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheck from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv from lightning.pytorch.demos.boring_classes import ( + BoringDataModule, BoringModel, RandomDataset, RandomIterableDataset, @@ -2059,10 +2060,13 @@ def test_trainer_calls_datamodule_on_exception(exception_type): def on_fit_start(self): raise exception + datamodule = BoringDataModule() + datamodule.on_exception = Mock() trainer = Trainer() - with mock.patch("lightning.pytorch.LightningDataModule.on_exception") as on_exception_mock, suppress(Exception): - trainer.fit(ExceptionModel(), datamodule=LightningDataModule()) - on_exception_mock.assert_called_once_with(exception) + + with suppress(Exception): + trainer.fit(ExceptionModel(), datamodule=datamodule) + datamodule.on_exception.assert_called_once_with(exception) def test_init_module_context(monkeypatch):