Raise an exception if check_val_every_n_epoch is not an integer (#6411)
* raise an exception if check_val_every_n_epoch is not an integer * remove unused object * add type hints * add return type * update exception message * update exception message
This commit is contained in:
parent
615b2f7363
commit
74d79e7e0e
|
@ -26,10 +26,17 @@ class DataConnector(object):
|
|||
def __init__(self, trainer):
|
||||
self.trainer = trainer
|
||||
|
||||
def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node):
|
||||
def on_trainer_init(
|
||||
self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
|
||||
) -> None:
|
||||
self.trainer.datamodule = None
|
||||
self.trainer.prepare_data_per_node = prepare_data_per_node
|
||||
|
||||
if not isinstance(check_val_every_n_epoch, int):
|
||||
raise MisconfigurationException(
|
||||
f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
|
||||
)
|
||||
|
||||
self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
|
||||
self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
|
||||
self.trainer._is_data_prepared = False
|
||||
|
|
|
@ -1828,3 +1828,13 @@ def test_init_optimizers_resets_lightning_optimizers(tmpdir):
|
|||
trainer.max_epochs = 2 # simulate multiple fit calls
|
||||
trainer.fit(model)
|
||||
compare_optimizers()
|
||||
|
||||
|
||||
def test_check_val_every_n_epoch_exception(tmpdir):
|
||||
|
||||
with pytest.raises(MisconfigurationException, match="should be an integer."):
|
||||
Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
check_val_every_n_epoch=1.2,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue