explicit test

This commit is contained in:
awaelchli 2024-03-22 16:22:40 +01:00
parent cfcb55e2ee
commit 64a93e0196
1 changed files with 7 additions and 3 deletions

View File

@ -38,6 +38,7 @@ from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheck
from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from lightning.pytorch.demos.boring_classes import (
BoringDataModule,
BoringModel,
RandomDataset,
RandomIterableDataset,
@ -2059,10 +2060,13 @@ def test_trainer_calls_datamodule_on_exception(exception_type):
def on_fit_start(self):
raise exception
datamodule = BoringDataModule()
datamodule.on_exception = Mock()
trainer = Trainer()
with mock.patch("lightning.pytorch.LightningDataModule.on_exception") as on_exception_mock, suppress(Exception):
trainer.fit(ExceptionModel(), datamodule=LightningDataModule())
on_exception_mock.assert_called_once_with(exception)
with suppress(Exception):
trainer.fit(ExceptionModel(), datamodule=datamodule)
datamodule.on_exception.assert_called_once_with(exception)
def test_init_module_context(monkeypatch):