Fix: passing wrong strings for scheduler interval doesn't throw an error (#5923)

* Raise if scheduler interval not 'step' or 'epoch'

* Add test for unknown 'interval' value in scheduler

* Use BoringModel instead of EvalModelTemplate

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Fix import order

* Apply yapf in test_datamodules

* Add missing imports to test_datamodules

* Fix too long comment

* Update pytorch_lightning/trainer/optimizers.py

* Fix unused imports and exception message

* Fix failing test

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Dusan Drevicky 2021-02-12 21:01:22 +01:00 committed by GitHub
parent ae19c9723b
commit 309ce7a966
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 0 deletions

View File

@ -117,6 +117,12 @@ class TrainerOptimizersMixin(ABC):
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
)
if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
raise MisconfigurationException(
f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
f' but is "{scheduler["interval"]}"'
)
scheduler['reduce_on_plateau'] = isinstance(
scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
)

View File

@ -459,6 +459,24 @@ def test_unknown_configure_optimizers_raises(tmpdir):
trainer.fit(model)
def test_lr_scheduler_with_unknown_interval_raises(tmpdir):
"""
Test exception when lr_scheduler dict has unknown interval param value
"""
model = BoringModel()
optimizer = torch.optim.Adam(model.parameters())
model.configure_optimizers = lambda: {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1),
'interval': "incorrect_unknown_value"
},
}
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be'):
trainer.fit(model)
def test_lr_scheduler_with_extra_keys_warns(tmpdir):
"""
Test warning when lr_scheduler dict has extra keys