Bugfix: Lr finder and hparams compatibility (#2821)

* fix hparams lr finder bug

* add tests for new functions

* better tests

* fix codefactor

* fix styling

* fix tests

* fix codefactor

* Apply suggestions from code review

* modified hook

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Nicki Skafte 2020-08-07 00:34:48 +02:00 committed by GitHub
parent 1dc411fc53
commit 9a402461da
7 changed files with 158 additions and 37 deletions


@@ -96,6 +96,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed shell injection vulnerability in subprocess call ([#2786](https://github.com/PyTorchLightning/pytorch-lightning/pull/2786))
- Fixed LR finder and `hparams` compatibility ([#2821](https://github.com/PyTorchLightning/pytorch-lightning/pull/2821))
## [0.8.5] - 2020-07-09
### Added


@@ -24,7 +24,7 @@ from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
class TrainerLRFinderMixin(ABC):
@@ -57,24 +57,26 @@ class TrainerLRFinderMixin(ABC):
""" Call lr finder internally during Trainer.fit() """
lr_finder = self.lr_find(model)
lr = lr_finder.suggestion()
# TODO: log lr.results to self.logger
if isinstance(self.auto_lr_find, str):
# Try to find requested field, may be nested
if _nested_hasattr(model, self.auto_lr_find):
_nested_setattr(model, self.auto_lr_find, lr)
if lightning_hasattr(model, self.auto_lr_find):
lightning_setattr(model, self.auto_lr_find, lr)
else:
raise MisconfigurationException(
f'`auto_lr_find` was set to {self.auto_lr_find}, however'
' could not find this as a field in `model.hparams`.')
' could not find this as a field in `model` or `model.hparams`.')
else:
if hasattr(model, 'lr'):
model.lr = lr
elif hasattr(model, 'learning_rate'):
model.learning_rate = lr
if lightning_hasattr(model, 'lr'):
lightning_setattr(model, 'lr', lr)
elif lightning_hasattr(model, 'learning_rate'):
lightning_setattr(model, 'learning_rate', lr)
else:
raise MisconfigurationException(
'When auto_lr_find is set to True, expects that hparams'
' either has field `lr` or `learning_rate` that can overridden')
'When auto_lr_find is set to True, expects that `model` or'
' `model.hparams` either has field `lr` or `learning_rate`'
' that can be overridden')
log.info(f'Learning rate set to {lr}')
def lr_find(
@@ -492,22 +494,3 @@ class _ExponentialLR(_LRScheduler):
@property
def lr(self):
return self._lr
def _nested_hasattr(obj, path):
parts = path.split(".")
for part in parts:
if hasattr(obj, part):
obj = getattr(obj, part)
else:
return False
else:
return True
def _nested_setattr(obj, path, val):
parts = path.split(".")
for part in parts[:-1]:
if hasattr(obj, part):
obj = getattr(obj, part)
setattr(obj, parts[-1], val)
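
For context, a minimal sketch of the user-facing behaviour this change enables: with auto_lr_find=True, the suggested learning rate is now written back through lightning_setattr, so it also reaches models that keep lr/learning_rate only in self.hparams rather than as a plain attribute. The module below (LitRegressor, random data, a single linear layer) is illustrative and not part of this PR; it assumes the save_hyperparameters() API available in this release.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class LitRegressor(pl.LightningModule):  # illustrative module, not part of this PR
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()  # stores `learning_rate` in self.hparams only
        self.layer = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        return {'loss': loss}

    def train_dataloader(self):
        x, y = torch.randn(64, 8), torch.randn(64, 1)
        return DataLoader(TensorDataset(x, y), batch_size=16)

    def configure_optimizers(self):
        # reads the lr from hparams, so the finder must write the suggestion back there
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


model = LitRegressor()
trainer = pl.Trainer(auto_lr_find=True, max_epochs=1)
trainer.fit(model)  # in this release the finder runs inside fit(); the suggested lr lands in model.hparams.learning_rate

Passing a string instead, e.g. Trainer(auto_lr_find='my_fancy_lr'), targets a custom field, which lightning_hasattr/lightning_setattr now resolve either on the model or inside model.hparams.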


@@ -140,3 +140,56 @@ class AttributeDict(Dict):
rows = [tmp_name.format(f'"{n}":', self[n]) for n in sorted(self.keys())]
out = '\n'.join(rows)
return out
def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
# Check if attribute in model
if hasattr(model, attribute):
attr = True
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
attr = attribute in model.hparams
else:
attr = hasattr(model.hparams, attribute)
else:
attr = False
return attr
def lightning_getattr(model, attribute):
""" Special getattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
# Check if attribute in model
if hasattr(model, attribute):
attr = getattr(model, attribute)
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
attr = model.hparams[attribute]
else:
attr = getattr(model.hparams, attribute)
else:
raise ValueError(f'{attribute} is not stored in the model namespace'
' or the `hparams` namespace/dict.')
return attr
def lightning_setattr(model, attribute, value):
""" Special setattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
# Check if attribute in model
if hasattr(model, attribute):
setattr(model, attribute, value)
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
model.hparams[attribute] = value
else:
setattr(model.hparams, attribute, value)
else:
raise ValueError(f'{attribute} is not stored in the model namespace'
' or the `hparams` namespace/dict.')
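
Taken together, these helpers implement one lookup order: a plain attribute on the model first, then model.hparams as a namespace, then model.hparams as a dict; lightning_hasattr returns False when nothing matches, while lightning_getattr and lightning_setattr raise ValueError. A small illustration with throwaway classes (not from this PR) follows.

from argparse import Namespace

from pytorch_lightning.utilities.parsing import (
    lightning_getattr,
    lightning_hasattr,
    lightning_setattr,
)


class PlainModel:  # attribute lives directly on the model
    learning_rate = 0.01


class NamespaceModel:  # attribute lives on model.hparams (namespace-style)
    hparams = Namespace(learning_rate=0.02)


class DictModel:  # attribute lives in model.hparams (dict-style)
    hparams = {'learning_rate': 0.03}


for model in (PlainModel(), NamespaceModel(), DictModel()):
    assert lightning_hasattr(model, 'learning_rate')
    lightning_setattr(model, 'learning_rate', 0.1)  # writes back to wherever the attribute was found
    assert lightning_getattr(model, 'learning_rate') == 0.1

assert not lightning_hasattr(PlainModel(), 'batch_size')  # found nowhere -> False, no exception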


@@ -70,3 +70,7 @@ class ConfigureOptimizersPool(ABC):
optimizer = optim.Adam(param_groups)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
return [optimizer], [lr_scheduler]
def configure_optimizers__lr_from_hparams(self):
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return optimizer


@@ -4,7 +4,7 @@ from argparse import ArgumentParser
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, seed_everything
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule
from tests.base.develop_utils import reset_seed
@@ -291,7 +291,7 @@ def test_full_loop_ddp_spawn(tmpdir):
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
reset_seed()
seed_everything(1234)
dm = TrialMNISTDataModule(tmpdir)


@@ -73,11 +73,15 @@ def test_trainer_reset_correctly(tmpdir):
f'Attribute {key} was not reset correctly after learning rate finder'
def test_trainer_arg_bool(tmpdir):
@pytest.mark.parametrize('use_hparams', [False, True])
def test_trainer_arg_bool(tmpdir, use_hparams):
""" Test that setting trainer arg to bool works """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
before_lr = hparams.get('learning_rate')
if use_hparams:
del model.learning_rate
model.configure_optimizers = model.configure_optimizers__lr_from_hparams
# logger file to get meta
trainer = Trainer(
@@ -87,17 +91,27 @@ def test_trainer_arg_bool(tmpdir):
)
trainer.fit(model)
after_lr = model.learning_rate
if use_hparams:
after_lr = model.hparams.learning_rate
else:
after_lr = model.learning_rate
assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'
def test_trainer_arg_str(tmpdir):
@pytest.mark.parametrize('use_hparams', [False, True])
def test_trainer_arg_str(tmpdir, use_hparams):
""" Test that setting trainer arg to string works """
model = EvalModelTemplate()
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
model.my_fancy_lr = 1.0 # update with non-standard field
model.hparams['my_fancy_lr'] = 1.0
before_lr = model.my_fancy_lr
if use_hparams:
del model.my_fancy_lr
model.configure_optimizers = model.configure_optimizers__lr_from_hparams
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
@@ -106,7 +120,11 @@ def test_trainer_arg_str(tmpdir):
)
trainer.fit(model)
after_lr = model.my_fancy_lr
if use_hparams:
after_lr = model.hparams.my_fancy_lr
else:
after_lr = model.my_fancy_lr
assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'


@@ -0,0 +1,61 @@
import pytest
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
def _get_test_cases():
class TestHparamsNamespace:
learning_rate = 1
TestHparamsDict = {'learning_rate': 2}
class TestModel1: # test for namespace
learning_rate = 0
model1 = TestModel1()
class TestModel2: # test for hparams namespace
hparams = TestHparamsNamespace()
model2 = TestModel2()
class TestModel3: # test for hparams dict
hparams = TestHparamsDict
model3 = TestModel3()
class TestModel4: # fail case
batch_size = 1
model4 = TestModel4()
return model1, model2, model3, model4
def test_lightning_hasattr(tmpdir):
""" Test that the lightning_hasattr works in all cases"""
model1, model2, model3, model4 = _get_test_cases()
assert lightning_hasattr(model1, 'learning_rate'), \
'lightning_hasattr failed to find namespace variable'
assert lightning_hasattr(model2, 'learning_rate'), \
'lightning_hasattr failed to find hparams namespace variable'
assert lightning_hasattr(model3, 'learning_rate'), \
'lightning_hasattr failed to find hparams dict variable'
assert not lightning_hasattr(model4, 'learning_rate'), \
'lightning_hasattr found variable when it should not'
def test_lightning_getattr(tmpdir):
""" Test that the lightning_getattr works in all cases"""
models = _get_test_cases()
for i, m in enumerate(models[:3]):
value = lightning_getattr(m, 'learning_rate')
assert value == i, 'attribute not correctly extracted'
def test_lightning_setattr(tmpdir):
""" Test that the lightning_setattr works in all cases"""
models = _get_test_cases()
for m in models[:3]:
lightning_setattr(m, 'learning_rate', 10)
assert lightning_getattr(m, 'learning_rate') == 10, \
'attribute not correctly set'