From 309ce7a96664316b0f42147d6a849001d841559a Mon Sep 17 00:00:00 2001 From: Dusan Drevicky <55678224+ddrevicky@users.noreply.github.com> Date: Fri, 12 Feb 2021 21:01:22 +0100 Subject: [PATCH] Fix: passing wrong strings for scheduler interval doesn't throw an error (#5923) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Raise if scheduler interval not 'step' or 'epoch' * Add test for unknown 'interval' value in scheduler * Use BoringModel instead of EvalModelTemplate Co-authored-by: Jirka Borovec * Fix import order * Apply yapf in test_datamodules * Add missing imports to test_datamodules * Fix too long comment * Update pytorch_lightning/trainer/optimizers.py * Fix unused imports and exception message * Fix failing test Co-authored-by: Jirka Borovec Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/trainer/optimizers.py | 6 ++++++ tests/trainer/optimization/test_optimizers.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 6772dcc645..6793a370fd 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -117,6 +117,12 @@ class TrainerOptimizersMixin(ABC): raise MisconfigurationException( 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler' ) + if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'): + raise MisconfigurationException( + f'The "interval" key in lr scheduler dict must be "step" or "epoch"' + f' but is "{scheduler["interval"]}"' + ) + scheduler['reduce_on_plateau'] = isinstance( scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau ) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index c9a9250995..7172b2dca7 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -459,6 +459,24 @@ def test_unknown_configure_optimizers_raises(tmpdir): trainer.fit(model) +def test_lr_scheduler_with_unknown_interval_raises(tmpdir): + """ + Test exception when lr_scheduler dict has unknown interval param value + """ + model = BoringModel() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1), + 'interval': "incorrect_unknown_value" + }, + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be'): + trainer.fit(model) + + def test_lr_scheduler_with_extra_keys_warns(tmpdir): """ Test warning when lr_scheduler dict has extra keys