From 74d79e7e0ec38b2f892f7424d224124a6b896ef0 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Wed, 10 Mar 2021 12:08:53 +0530 Subject: [PATCH] Raise an exception if check_val_every_n_epoch is not an integer (#6411) * raise an exception if check_val_every_n_epoch is not an integer * remove unused object * add type hints * add return type * update exception message * update exception message --- pytorch_lightning/trainer/connectors/data_connector.py | 9 ++++++++- tests/trainer/test_trainer.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9e08cf0311..7a0e0f39ca 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -26,10 +26,17 @@ class DataConnector(object): def __init__(self, trainer): self.trainer = trainer - def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node): + def on_trainer_init( + self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool + ) -> None: self.trainer.datamodule = None self.trainer.prepare_data_per_node = prepare_data_per_node + if not isinstance(check_val_every_n_epoch, int): + raise MisconfigurationException( + f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}" + ) + self.trainer.check_val_every_n_epoch = check_val_every_n_epoch self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch self.trainer._is_data_prepared = False diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 385d8c1c6b..f1a3687b43 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1828,3 +1828,13 @@ def test_init_optimizers_resets_lightning_optimizers(tmpdir): trainer.max_epochs = 2 # simulate multiple fit calls trainer.fit(model) compare_optimizers() + + +def test_check_val_every_n_epoch_exception(tmpdir): + + with pytest.raises(MisconfigurationException, match="should be an integer."): + Trainer( + default_root_dir=tmpdir, + max_epochs=1, + check_val_every_n_epoch=1.2, + )