Bugfix: Lr finder and hparams compatibility (#2821)
* fix hparams lr finder bug * add tests for new functions * better tests * fix codefactor * fix styling * fix tests * fix codefactor * Apply suggestions from code review * modified hook Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
parent
1dc411fc53
commit
9a402461da
|
@ -96,6 +96,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed shell injection vulnerability in subprocess call ([#2786](https://github.com/PyTorchLightning/pytorch-lightning/pull/2786))
|
||||
|
||||
- Fixed LR finder and `hparams` compatibility ([#2821](https://github.com/PyTorchLightning/pytorch-lightning/pull/2821))
|
||||
|
||||
## [0.8.5] - 2020-07-09
|
||||
|
||||
### Added
|
||||
|
|
|
@ -24,7 +24,7 @@ from pytorch_lightning.callbacks import Callback
|
|||
from pytorch_lightning.loggers.base import DummyLogger
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
|
||||
|
||||
|
||||
class TrainerLRFinderMixin(ABC):
|
||||
|
@ -57,24 +57,26 @@ class TrainerLRFinderMixin(ABC):
|
|||
""" Call lr finder internally during Trainer.fit() """
|
||||
lr_finder = self.lr_find(model)
|
||||
lr = lr_finder.suggestion()
|
||||
|
||||
# TODO: log lr.results to self.logger
|
||||
if isinstance(self.auto_lr_find, str):
|
||||
# Try to find requested field, may be nested
|
||||
if _nested_hasattr(model, self.auto_lr_find):
|
||||
_nested_setattr(model, self.auto_lr_find, lr)
|
||||
if lightning_hasattr(model, self.auto_lr_find):
|
||||
lightning_setattr(model, self.auto_lr_find, lr)
|
||||
else:
|
||||
raise MisconfigurationException(
|
||||
f'`auto_lr_find` was set to {self.auto_lr_find}, however'
|
||||
' could not find this as a field in `model.hparams`.')
|
||||
' could not find this as a field in `model` or `model.hparams`.')
|
||||
else:
|
||||
if hasattr(model, 'lr'):
|
||||
model.lr = lr
|
||||
elif hasattr(model, 'learning_rate'):
|
||||
model.learning_rate = lr
|
||||
if lightning_hasattr(model, 'lr'):
|
||||
lightning_setattr(model, 'lr', lr)
|
||||
elif lightning_hasattr(model, 'learning_rate'):
|
||||
lightning_setattr(model, 'learning_rate', lr)
|
||||
else:
|
||||
raise MisconfigurationException(
|
||||
'When auto_lr_find is set to True, expects that hparams'
|
||||
' either has field `lr` or `learning_rate` that can overridden')
|
||||
'When auto_lr_find is set to True, expects that `model` or'
|
||||
' `model.hparams` either has field `lr` or `learning_rate`'
|
||||
' that can overridden')
|
||||
log.info(f'Learning rate set to {lr}')
|
||||
|
||||
def lr_find(
|
||||
|
@ -492,22 +494,3 @@ class _ExponentialLR(_LRScheduler):
|
|||
@property
|
||||
def lr(self):
|
||||
return self._lr
|
||||
|
||||
|
||||
def _nested_hasattr(obj, path):
|
||||
parts = path.split(".")
|
||||
for part in parts:
|
||||
if hasattr(obj, part):
|
||||
obj = getattr(obj, part)
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def _nested_setattr(obj, path, val):
|
||||
parts = path.split(".")
|
||||
for part in parts[:-1]:
|
||||
if hasattr(obj, part):
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], val)
|
||||
|
|
|
@ -140,3 +140,56 @@ class AttributeDict(Dict):
|
|||
rows = [tmp_name.format(f'"{n}":', self[n]) for n in sorted(self.keys())]
|
||||
out = '\n'.join(rows)
|
||||
return out
|
||||
|
||||
|
||||
def lightning_hasattr(model, attribute):
|
||||
""" Special hasattr for lightning. Checks for attribute in model namespace
|
||||
and the old hparams namespace/dict """
|
||||
# Check if attribute in model
|
||||
if hasattr(model, attribute):
|
||||
attr = True
|
||||
# Check if attribute in model.hparams, either namespace or dict
|
||||
elif hasattr(model, 'hparams'):
|
||||
if isinstance(model.hparams, dict):
|
||||
attr = attribute in model.hparams
|
||||
else:
|
||||
attr = hasattr(model.hparams, attribute)
|
||||
else:
|
||||
attr = False
|
||||
|
||||
return attr
|
||||
|
||||
|
||||
def lightning_getattr(model, attribute):
|
||||
""" Special getattr for lightning. Checks for attribute in model namespace
|
||||
and the old hparams namespace/dict """
|
||||
# Check if attribute in model
|
||||
if hasattr(model, attribute):
|
||||
attr = getattr(model, attribute)
|
||||
# Check if attribute in model.hparams, either namespace or dict
|
||||
elif hasattr(model, 'hparams'):
|
||||
if isinstance(model.hparams, dict):
|
||||
attr = model.hparams[attribute]
|
||||
else:
|
||||
attr = getattr(model.hparams, attribute)
|
||||
else:
|
||||
raise ValueError(f'{attribute} is not stored in the model namespace'
|
||||
' or the `hparams` namespace/dict.')
|
||||
return attr
|
||||
|
||||
|
||||
def lightning_setattr(model, attribute, value):
|
||||
""" Special setattr for lightning. Checks for attribute in model namespace
|
||||
and the old hparams namespace/dict """
|
||||
# Check if attribute in model
|
||||
if hasattr(model, attribute):
|
||||
setattr(model, attribute, value)
|
||||
# Check if attribute in model.hparams, either namespace or dict
|
||||
elif hasattr(model, 'hparams'):
|
||||
if isinstance(model.hparams, dict):
|
||||
model.hparams[attribute] = value
|
||||
else:
|
||||
setattr(model.hparams, attribute, value)
|
||||
else:
|
||||
raise ValueError(f'{attribute} is not stored in the model namespace'
|
||||
' or the `hparams` namespace/dict.')
|
||||
|
|
|
@ -70,3 +70,7 @@ class ConfigureOptimizersPool(ABC):
|
|||
optimizer = optim.Adam(param_groups)
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
def configure_optimizers__lr_from_hparams(self):
|
||||
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
|
||||
return optimizer
|
||||
|
|
|
@ -4,7 +4,7 @@ from argparse import ArgumentParser
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.base.datamodules import TrialMNISTDataModule
|
||||
from tests.base.develop_utils import reset_seed
|
||||
|
@ -291,7 +291,7 @@ def test_full_loop_ddp_spawn(tmpdir):
|
|||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
||||
|
||||
reset_seed()
|
||||
seed_everything(1234)
|
||||
|
||||
dm = TrialMNISTDataModule(tmpdir)
|
||||
|
||||
|
|
|
@ -73,11 +73,15 @@ def test_trainer_reset_correctly(tmpdir):
|
|||
f'Attribute {key} was not reset correctly after learning rate finder'
|
||||
|
||||
|
||||
def test_trainer_arg_bool(tmpdir):
|
||||
@pytest.mark.parametrize('use_hparams', [False, True])
|
||||
def test_trainer_arg_bool(tmpdir, use_hparams):
|
||||
""" Test that setting trainer arg to bool works """
|
||||
hparams = EvalModelTemplate.get_default_hparams()
|
||||
model = EvalModelTemplate(**hparams)
|
||||
before_lr = hparams.get('learning_rate')
|
||||
if use_hparams:
|
||||
del model.learning_rate
|
||||
model.configure_optimizers = model.configure_optimizers__lr_from_hparams
|
||||
|
||||
# logger file to get meta
|
||||
trainer = Trainer(
|
||||
|
@ -87,17 +91,27 @@ def test_trainer_arg_bool(tmpdir):
|
|||
)
|
||||
|
||||
trainer.fit(model)
|
||||
after_lr = model.learning_rate
|
||||
if use_hparams:
|
||||
after_lr = model.hparams.learning_rate
|
||||
else:
|
||||
after_lr = model.learning_rate
|
||||
|
||||
assert before_lr != after_lr, \
|
||||
'Learning rate was not altered after running learning rate finder'
|
||||
|
||||
|
||||
def test_trainer_arg_str(tmpdir):
|
||||
@pytest.mark.parametrize('use_hparams', [False, True])
|
||||
def test_trainer_arg_str(tmpdir, use_hparams):
|
||||
""" Test that setting trainer arg to string works """
|
||||
model = EvalModelTemplate()
|
||||
hparams = EvalModelTemplate.get_default_hparams()
|
||||
model = EvalModelTemplate(**hparams)
|
||||
model.my_fancy_lr = 1.0 # update with non-standard field
|
||||
|
||||
model.hparams['my_fancy_lr'] = 1.0
|
||||
before_lr = model.my_fancy_lr
|
||||
if use_hparams:
|
||||
del model.my_fancy_lr
|
||||
model.configure_optimizers = model.configure_optimizers__lr_from_hparams
|
||||
|
||||
# logger file to get meta
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
|
@ -106,7 +120,11 @@ def test_trainer_arg_str(tmpdir):
|
|||
)
|
||||
|
||||
trainer.fit(model)
|
||||
after_lr = model.my_fancy_lr
|
||||
if use_hparams:
|
||||
after_lr = model.hparams.my_fancy_lr
|
||||
else:
|
||||
after_lr = model.my_fancy_lr
|
||||
|
||||
assert before_lr != after_lr, \
|
||||
'Learning rate was not altered after running learning rate finder'
|
||||
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
import pytest
|
||||
|
||||
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
|
||||
|
||||
|
||||
def _get_test_cases():
|
||||
class TestHparamsNamespace:
|
||||
learning_rate = 1
|
||||
|
||||
TestHparamsDict = {'learning_rate': 2}
|
||||
|
||||
class TestModel1: # test for namespace
|
||||
learning_rate = 0
|
||||
model1 = TestModel1()
|
||||
|
||||
class TestModel2: # test for hparams namespace
|
||||
hparams = TestHparamsNamespace()
|
||||
|
||||
model2 = TestModel2()
|
||||
|
||||
class TestModel3: # test for hparams dict
|
||||
hparams = TestHparamsDict
|
||||
|
||||
model3 = TestModel3()
|
||||
|
||||
class TestModel4: # fail case
|
||||
batch_size = 1
|
||||
|
||||
model4 = TestModel4()
|
||||
|
||||
return model1, model2, model3, model4
|
||||
|
||||
|
||||
def test_lightning_hasattr(tmpdir):
|
||||
""" Test that the lightning_hasattr works in all cases"""
|
||||
model1, model2, model3, model4 = _get_test_cases()
|
||||
assert lightning_hasattr(model1, 'learning_rate'), \
|
||||
'lightning_hasattr failed to find namespace variable'
|
||||
assert lightning_hasattr(model2, 'learning_rate'), \
|
||||
'lightning_hasattr failed to find hparams namespace variable'
|
||||
assert lightning_hasattr(model3, 'learning_rate'), \
|
||||
'lightning_hasattr failed to find hparams dict variable'
|
||||
assert not lightning_hasattr(model4, 'learning_rate'), \
|
||||
'lightning_hasattr found variable when it should not'
|
||||
|
||||
|
||||
def test_lightning_getattr(tmpdir):
|
||||
""" Test that the lightning_getattr works in all cases"""
|
||||
models = _get_test_cases()
|
||||
for i, m in enumerate(models[:3]):
|
||||
value = lightning_getattr(m, 'learning_rate')
|
||||
assert value == i, 'attribute not correctly extracted'
|
||||
|
||||
|
||||
def test_lightning_setattr(tmpdir):
|
||||
""" Test that the lightning_setattr works in all cases"""
|
||||
models = _get_test_cases()
|
||||
for m in models[:3]:
|
||||
lightning_setattr(m, 'learning_rate', 10)
|
||||
assert lightning_getattr(m, 'learning_rate') == 10, \
|
||||
'attribute not correctly set'
|
Loading…
Reference in New Issue