Fix: passing wrong strings for scheduler interval doesn't throw an error (#5923)
* Raise if scheduler interval not 'step' or 'epoch' * Add test for unknown 'interval' value in scheduler * Use BoringModel instead of EvalModelTemplate Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Fix import order * Apply yapf in test_datamodules * Add missing imports to test_datamodules * Fix too long comment * Update pytorch_lightning/trainer/optimizers.py * Fix unused imports and exception message * Fix failing test Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
ae19c9723b
commit
309ce7a966
|
@ -117,6 +117,12 @@ class TrainerOptimizersMixin(ABC):
|
|||
raise MisconfigurationException(
|
||||
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
|
||||
)
|
||||
if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
|
||||
raise MisconfigurationException(
|
||||
f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
|
||||
f' but is "{scheduler["interval"]}"'
|
||||
)
|
||||
|
||||
scheduler['reduce_on_plateau'] = isinstance(
|
||||
scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
|
||||
)
|
||||
|
|
|
@ -459,6 +459,24 @@ def test_unknown_configure_optimizers_raises(tmpdir):
|
|||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_lr_scheduler_with_unknown_interval_raises(tmpdir):
|
||||
"""
|
||||
Test exception when lr_scheduler dict has unknown interval param value
|
||||
"""
|
||||
model = BoringModel()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
model.configure_optimizers = lambda: {
|
||||
'optimizer': optimizer,
|
||||
'lr_scheduler': {
|
||||
'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1),
|
||||
'interval': "incorrect_unknown_value"
|
||||
},
|
||||
}
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
|
||||
with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be'):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_lr_scheduler_with_extra_keys_warns(tmpdir):
|
||||
"""
|
||||
Test warning when lr_scheduler dict has extra keys
|
||||
|
|
Loading…
Reference in New Issue