From 9a402461dab2c3797f06da7273d05145b46e5fe6 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 7 Aug 2020 00:34:48 +0200
Subject: [PATCH] Bugfix: Lr finder and hparams compatibility (#2821)

* fix hparams lr finder bug

* add tests for new functions

* better tests

* fix codefactor

* fix styling

* fix tests

* fix codefactor

* Apply suggestions from code review

* modified hook

Co-authored-by: Jirka Borovec
Co-authored-by: William Falcon
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 CHANGELOG.md                           |  2 +
 pytorch_lightning/trainer/lr_finder.py | 41 +++++------
 pytorch_lightning/utilities/parsing.py | 53 ++++++++++++++++++++++
 tests/base/model_optimizers.py         |  4 ++
 tests/core/test_datamodules.py         |  4 +-
 tests/trainer/test_lr_finder.py        | 30 ++++++++++---
 tests/utilities/parsing.py             | 61 ++++++++++++++++++++++++++
 7 files changed, 158 insertions(+), 37 deletions(-)
 create mode 100644 tests/utilities/parsing.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bf8d002bce..75c8fe3ef4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -96,6 +96,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed shell injection vulnerability in subprocess call ([#2786](https://github.com/PyTorchLightning/pytorch-lightning/pull/2786))
 
+- Fixed LR finder and `hparams` compatibility ([#2821](https://github.com/PyTorchLightning/pytorch-lightning/pull/2821))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added
diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py
index 23ad702956..1b4fb8695e 100755
--- a/pytorch_lightning/trainer/lr_finder.py
+++ b/pytorch_lightning/trainer/lr_finder.py
@@ -24,7 +24,7 @@ from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 
 
 class TrainerLRFinderMixin(ABC):
@@ -57,24 +57,26 @@ class TrainerLRFinderMixin(ABC):
         """ Call lr finder internally during Trainer.fit() """
         lr_finder = self.lr_find(model)
         lr = lr_finder.suggestion()
+        # TODO: log lr.results to self.logger
         if isinstance(self.auto_lr_find, str):
             # Try to find requested field, may be nested
-            if _nested_hasattr(model, self.auto_lr_find):
-                _nested_setattr(model, self.auto_lr_find, lr)
+            if lightning_hasattr(model, self.auto_lr_find):
+                lightning_setattr(model, self.auto_lr_find, lr)
             else:
                 raise MisconfigurationException(
                     f'`auto_lr_find` was set to {self.auto_lr_find}, however'
-                    ' could not find this as a field in `model.hparams`.')
+                    ' could not find this as a field in `model` or `model.hparams`.')
         else:
-            if hasattr(model, 'lr'):
-                model.lr = lr
-            elif hasattr(model, 'learning_rate'):
-                model.learning_rate = lr
+            if lightning_hasattr(model, 'lr'):
+                lightning_setattr(model, 'lr', lr)
+            elif lightning_hasattr(model, 'learning_rate'):
+                lightning_setattr(model, 'learning_rate', lr)
             else:
                 raise MisconfigurationException(
-                    'When auto_lr_find is set to True, expects that hparams'
-                    ' either has field `lr` or `learning_rate` that can overridden')
+                    'When auto_lr_find is set to True, expects that `model` or'
+                    ' `model.hparams` either has field `lr` or `learning_rate`'
+                    ' that can be overridden')
         log.info(f'Learning rate set to {lr}')
 
     def lr_find(
@@ -492,22 +494,3 @@ class _ExponentialLR(_LRScheduler):
     @property
     def lr(self):
         return self._lr
-
-
-def _nested_hasattr(obj, path):
-    parts = path.split(".")
-    for part in parts:
-        if hasattr(obj, part):
-            obj = getattr(obj, part)
-        else:
-            return False
-    else:
-        return True
-
-
-def _nested_setattr(obj, path, val):
-    parts = path.split(".")
-    for part in parts[:-1]:
-        if hasattr(obj, part):
-            obj = getattr(obj, part)
-    setattr(obj, parts[-1], val)
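A minimal sketch of the fallback logic introduced above, not part of the patch: it uses a toy stand-in class instead of a real `LightningModule`, assumes an installation that already contains the `lightning_hasattr`/`lightning_setattr` helpers added in `pytorch_lightning/utilities/parsing.py` below, and the helper name `apply_suggested_lr` is hypothetical.

from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr

class ToyModel:
    # the learning rate lives only in a dict-style `hparams`, not on the model itself
    hparams = {'learning_rate': 1e-3}

def apply_suggested_lr(model, lr):
    # mirrors the `else` branch of `_run_lr_finder_internally` after this patch:
    # prefer `lr`, then `learning_rate`, looking at the model first and `hparams` second
    if lightning_hasattr(model, 'lr'):
        lightning_setattr(model, 'lr', lr)
    elif lightning_hasattr(model, 'learning_rate'):
        lightning_setattr(model, 'learning_rate', lr)
    else:
        raise ValueError('model defines neither `lr` nor `learning_rate`')

model = ToyModel()
apply_suggested_lr(model, 0.02)
print(model.hparams['learning_rate'])  # 0.02 -- written back into hparams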
diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py
index 2af67d5eb9..8c835a1820 100644
--- a/pytorch_lightning/utilities/parsing.py
+++ b/pytorch_lightning/utilities/parsing.py
@@ -140,3 +140,56 @@ class AttributeDict(Dict):
         rows = [tmp_name.format(f'"{n}":', self[n]) for n in sorted(self.keys())]
         out = '\n'.join(rows)
         return out
+
+
+def lightning_hasattr(model, attribute):
+    """ Special hasattr for lightning. Checks for attribute in model namespace
+    and the old hparams namespace/dict """
+    # Check if attribute in model
+    if hasattr(model, attribute):
+        attr = True
+    # Check if attribute in model.hparams, either namespace or dict
+    elif hasattr(model, 'hparams'):
+        if isinstance(model.hparams, dict):
+            attr = attribute in model.hparams
+        else:
+            attr = hasattr(model.hparams, attribute)
+    else:
+        attr = False
+
+    return attr
+
+
+def lightning_getattr(model, attribute):
+    """ Special getattr for lightning. Checks for attribute in model namespace
+    and the old hparams namespace/dict """
+    # Check if attribute in model
+    if hasattr(model, attribute):
+        attr = getattr(model, attribute)
+    # Check if attribute in model.hparams, either namespace or dict
+    elif hasattr(model, 'hparams'):
+        if isinstance(model.hparams, dict):
+            attr = model.hparams[attribute]
+        else:
+            attr = getattr(model.hparams, attribute)
+    else:
+        raise ValueError(f'{attribute} is not stored in the model namespace'
+                         ' or the `hparams` namespace/dict.')
+    return attr
+
+
+def lightning_setattr(model, attribute, value):
+    """ Special setattr for lightning. Checks for attribute in model namespace
+    and the old hparams namespace/dict """
+    # Check if attribute in model
+    if hasattr(model, attribute):
+        setattr(model, attribute, value)
+    # Check if attribute in model.hparams, either namespace or dict
+    elif hasattr(model, 'hparams'):
+        if isinstance(model.hparams, dict):
+            model.hparams[attribute] = value
+        else:
+            setattr(model.hparams, attribute, value)
+    else:
+        raise ValueError(f'{attribute} is not stored in the model namespace'
+                         ' or the `hparams` namespace/dict.')
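A usage sketch of the three helpers added here (illustrative classes, not part of the patch): the model namespace is checked first and only then `hparams`, which may be either a namespace or a plain dict, so a direct attribute takes precedence when both exist.

from argparse import Namespace

from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr

class DictModel:
    hparams = {'learning_rate': 0.01}        # dict-style hparams only

class NamespaceModel:
    learning_rate = 0.1                      # direct attribute ...
    hparams = Namespace(learning_rate=0.01)  # ... plus namespace-style hparams

assert lightning_hasattr(DictModel(), 'learning_rate')
assert lightning_getattr(DictModel(), 'learning_rate') == 0.01

model = NamespaceModel()
assert lightning_getattr(model, 'learning_rate') == 0.1   # the model attribute wins
lightning_setattr(model, 'learning_rate', 0.5)
assert model.learning_rate == 0.5                         # written to the model ...
assert model.hparams.learning_rate == 0.01                # ... hparams left untouched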
diff --git a/tests/base/model_optimizers.py b/tests/base/model_optimizers.py
index 6386d925bd..4628a92807 100644
--- a/tests/base/model_optimizers.py
+++ b/tests/base/model_optimizers.py
@@ -70,3 +70,7 @@ class ConfigureOptimizersPool(ABC):
         optimizer = optim.Adam(param_groups)
         lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
         return [optimizer], [lr_scheduler]
+
+    def configure_optimizers__lr_from_hparams(self):
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        return optimizer
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index fd4d3c082e..c0c41be42f 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -4,7 +4,7 @@ from argparse import ArgumentParser
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, seed_everything
 from tests.base import EvalModelTemplate
 from tests.base.datamodules import TrialMNISTDataModule
 from tests.base.develop_utils import reset_seed
@@ -291,7 +291,7 @@ def test_full_loop_ddp_spawn(tmpdir):
     import os
     os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
 
-    reset_seed()
+    seed_everything(1234)
 
     dm = TrialMNISTDataModule(tmpdir)
 
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index ed0421a077..37bdb90ce6 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -73,11 +73,15 @@ def test_trainer_reset_correctly(tmpdir):
         f'Attribute {key} was not reset correctly after learning rate finder'
 
 
-def test_trainer_arg_bool(tmpdir):
+@pytest.mark.parametrize('use_hparams', [False, True])
+def test_trainer_arg_bool(tmpdir, use_hparams):
     """ Test that setting trainer arg to bool works """
     hparams = EvalModelTemplate.get_default_hparams()
     model = EvalModelTemplate(**hparams)
     before_lr = hparams.get('learning_rate')
+    if use_hparams:
+        del model.learning_rate
+        model.configure_optimizers = model.configure_optimizers__lr_from_hparams
 
     # logger file to get meta
     trainer = Trainer(
@@ -87,17 +91,27 @@ def test_trainer_arg_bool(tmpdir):
     )
 
     trainer.fit(model)
-    after_lr = model.learning_rate
+    if use_hparams:
+        after_lr = model.hparams.learning_rate
+    else:
+        after_lr = model.learning_rate
+
     assert before_lr != after_lr, \
         'Learning rate was not altered after running learning rate finder'
 
 
-def test_trainer_arg_str(tmpdir):
+@pytest.mark.parametrize('use_hparams', [False, True])
+def test_trainer_arg_str(tmpdir, use_hparams):
     """ Test that setting trainer arg to string works """
-    model = EvalModelTemplate()
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
     model.my_fancy_lr = 1.0  # update with non-standard field
-
+    model.hparams['my_fancy_lr'] = 1.0
     before_lr = model.my_fancy_lr
+    if use_hparams:
+        del model.my_fancy_lr
+        model.configure_optimizers = model.configure_optimizers__lr_from_hparams
+
     # logger file to get meta
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=2,
         auto_lr_find='my_fancy_lr',  # Find learning rate for field called `my_fancy_lr`
     )
 
     trainer.fit(model)
-    after_lr = model.my_fancy_lr
+    if use_hparams:
+        after_lr = model.hparams.my_fancy_lr
+    else:
+        after_lr = model.my_fancy_lr
+
     assert before_lr != after_lr, \
         'Learning rate was not altered after running learning rate finder'
diff --git a/tests/utilities/parsing.py b/tests/utilities/parsing.py
new file mode 100644
index 0000000000..469eca74e7
--- /dev/null
+++ b/tests/utilities/parsing.py
@@ -0,0 +1,61 @@
+import pytest
+
+from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
+
+
+def _get_test_cases():
+    class TestHparamsNamespace:
+        learning_rate = 1
+
+    TestHparamsDict = {'learning_rate': 2}
+
+    class TestModel1:  # test for namespace
+        learning_rate = 0
+    model1 = TestModel1()
+
+    class TestModel2:  # test for hparams namespace
+        hparams = TestHparamsNamespace()
+
+    model2 = TestModel2()
+
+    class TestModel3:  # test for hparams dict
+        hparams = TestHparamsDict
+
+    model3 = TestModel3()
+
+    class TestModel4:  # fail case
+        batch_size = 1
+
+    model4 = TestModel4()
+
+    return model1, model2, model3, model4
+
+
+def test_lightning_hasattr(tmpdir):
+    """ Test that the lightning_hasattr works in all cases"""
+    model1, model2, model3, model4 = _get_test_cases()
+    assert lightning_hasattr(model1, 'learning_rate'), \
+        'lightning_hasattr failed to find namespace variable'
+    assert lightning_hasattr(model2, 'learning_rate'), \
+        'lightning_hasattr failed to find hparams namespace variable'
+    assert lightning_hasattr(model3, 'learning_rate'), \
+        'lightning_hasattr failed to find hparams dict variable'
+    assert not lightning_hasattr(model4, 'learning_rate'), \
+        'lightning_hasattr found variable when it should not'
+
+
+def test_lightning_getattr(tmpdir):
+    """ Test that the lightning_getattr works in all cases"""
+    models = _get_test_cases()
+    for i, m in enumerate(models[:3]):
+        value = lightning_getattr(m, 'learning_rate')
+        assert value == i, 'attribute not correctly extracted'
+
+
+def test_lightning_setattr(tmpdir):
+    """ Test that the lightning_setattr works in all cases"""
+    models = _get_test_cases()
+    for m in models[:3]:
+        lightning_setattr(m, 'learning_rate', 10)
+        assert lightning_getattr(m, 'learning_rate') == 10, \
+            'attribute not correctly set'
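For completeness, a hedged sketch (not part of the patch) of the error path, modeled on the `model4` fail case above: `lightning_getattr` and `lightning_setattr` raise `ValueError` when the attribute is found neither on the model nor in `hparams`.

import pytest

from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_setattr

class NoLrModel:  # neither a direct attribute nor an `hparams` entry
    batch_size = 1

def test_lightning_getattr_setattr_missing_attribute():
    model = NoLrModel()
    with pytest.raises(ValueError):
        lightning_getattr(model, 'learning_rate')
    with pytest.raises(ValueError):
        lightning_setattr(model, 'learning_rate', 10)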