diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9e08cf0311..7a0e0f39ca 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -26,10 +26,17 @@ class DataConnector(object): def __init__(self, trainer): self.trainer = trainer - def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node): + def on_trainer_init( + self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool + ) -> None: self.trainer.datamodule = None self.trainer.prepare_data_per_node = prepare_data_per_node + if not isinstance(check_val_every_n_epoch, int): + raise MisconfigurationException( + f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}" + ) + self.trainer.check_val_every_n_epoch = check_val_every_n_epoch self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch self.trainer._is_data_prepared = False diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 385d8c1c6b..f1a3687b43 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1828,3 +1828,13 @@ def test_init_optimizers_resets_lightning_optimizers(tmpdir): trainer.max_epochs = 2 # simulate multiple fit calls trainer.fit(model) compare_optimizers() + + +def test_check_val_every_n_epoch_exception(tmpdir): + + with pytest.raises(MisconfigurationException, match="should be an integer."): + Trainer( + default_root_dir=tmpdir, + max_epochs=1, + check_val_every_n_epoch=1.2, + )