Unify attribute finding logic, fix not using dataloader when hparams present (#4559)

* Rebase onto master

* indent fix

* Remove duplicated logic

* Use single return

* Remove extra else

* add `__contains__` to TestHparamsNamespace to fix tests

* Fix lightning_setattr to set all valid attributes

* update doc

* better names

* fix holder order preference

* tests for new behavior

* Comment about using the last holder

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
Ryan Nett 2021-01-25 18:57:17 -08:00 committed by GitHub
parent c76cc23b3d
commit eee3b1a284
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 52 deletions

View File

@ -198,48 +198,56 @@ class AttributeDict(Dict):
return out
def lightning_get_all_attr_holders(model, attribute):
""" Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
trainer = getattr(model, 'trainer', None)
holders = []
# Check if attribute in model
if hasattr(model, attribute):
holders.append(model)
# Check if attribute in model.hparams, either namespace or dict
if hasattr(model, 'hparams'):
if attribute in model.hparams:
holders.append(model.hparams)
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
holders.append(trainer.datamodule)
return holders
def lightning_get_first_attr_holder(model, attribute):
""" Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule, returns the last one that has it. """
holders = lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
return None
# using the last holder to preserve backwards compatibility
return holders[-1]
def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
trainer = getattr(model, 'trainer', None)
attr = False
# Check if attribute in model
if hasattr(model, attribute):
attr = True
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
attr = attribute in model.hparams
else:
attr = hasattr(model.hparams, attribute)
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if not attr and trainer is not None:
attr = hasattr(trainer.datamodule, attribute)
return attr
return lightning_get_first_attr_holder(model, attribute) is not None
def lightning_getattr(model, attribute):
""" Special getattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
trainer = getattr(model, 'trainer', None)
# Check if attribute in model
if hasattr(model, attribute):
attr = getattr(model, attribute)
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams:
attr = model.hparams[attribute]
elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute):
attr = getattr(model.hparams, attribute)
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
attr = getattr(trainer.datamodule, attribute)
else:
holder = lightning_get_first_attr_holder(model, attribute)
if holder is None:
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')
return attr
if isinstance(holder, dict):
return holder[attribute]
return getattr(holder, attribute)
def lightning_setattr(model, attribute, value):
@ -247,23 +255,13 @@ def lightning_setattr(model, attribute, value):
and the old hparams namespace/dict.
Will also set the attribute on datamodule, if it exists.
"""
if not lightning_hasattr(model, attribute):
holders = lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')
trainer = getattr(model, 'trainer', None)
# Check if attribute in model
if hasattr(model, attribute):
setattr(model, attribute, value)
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
model.hparams[attribute] = value
for holder in holders:
if isinstance(holder, dict):
holder[attribute] = value
else:
setattr(model.hparams, attribute, value)
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
setattr(trainer.datamodule, attribute, value)
setattr(holder, attribute, value)

View File

@ -20,6 +20,9 @@ def _get_test_cases():
class TestHparamsNamespace:
learning_rate = 1
def __contains__(self, item):
return item == "learning_rate"
TestHparamsDict = {'learning_rate': 2}
class TestModel1: # test for namespace
@ -53,12 +56,26 @@ def _get_test_cases():
model5 = TestModel5()
return model1, model2, model3, model4, model5
class TestModel6: # test for datamodule w/ hparams w/o attribute (should use datamodule)
trainer = Trainer
hparams = TestHparamsDict
model6 = TestModel6()
TestHparamsDict2 = {'batch_size': 2}
class TestModel7: # test for datamodule w/ hparams w/ attribute (should use datamodule)
trainer = Trainer
hparams = TestHparamsDict2
model7 = TestModel7()
return model1, model2, model3, model4, model5, model6, model7
def test_lightning_hasattr(tmpdir):
""" Test that the lightning_hasattr works in all cases"""
model1, model2, model3, model4, model5 = _get_test_cases()
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
assert lightning_hasattr(model1, 'learning_rate'), \
'lightning_hasattr failed to find namespace variable'
assert lightning_hasattr(model2, 'learning_rate'), \
@ -69,6 +86,10 @@ def test_lightning_hasattr(tmpdir):
'lightning_hasattr found variable when it should not'
assert lightning_hasattr(model5, 'batch_size'), \
'lightning_hasattr failed to find batch_size in datamodule'
assert lightning_hasattr(model6, 'batch_size'), \
'lightning_hasattr failed to find batch_size in datamodule w/ hparams present'
assert lightning_hasattr(model7, 'batch_size'), \
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'
def test_lightning_getattr(tmpdir):
@ -78,9 +99,13 @@ def test_lightning_getattr(tmpdir):
value = lightning_getattr(m, 'learning_rate')
assert value == i, 'attribute not correctly extracted'
model5 = models[4]
model5, model6, model7 = models[4:]
assert lightning_getattr(model5, 'batch_size') == 8, \
'batch_size not correctly extracted'
assert lightning_getattr(model6, 'batch_size') == 8, \
'batch_size not correctly extracted'
assert lightning_getattr(model7, 'batch_size') == 8, \
'batch_size not correctly extracted'
def test_lightning_setattr(tmpdir):
@ -91,7 +116,13 @@ def test_lightning_setattr(tmpdir):
assert lightning_getattr(m, 'learning_rate') == 10, \
'attribute not correctly set'
model5 = models[4]
model5, model6, model7 = models[4:]
lightning_setattr(model5, 'batch_size', 128)
lightning_setattr(model6, 'batch_size', 128)
lightning_setattr(model7, 'batch_size', 128)
assert lightning_getattr(model5, 'batch_size') == 128, \
'batch_size not correctly set'
assert lightning_getattr(model6, 'batch_size') == 128, \
'batch_size not correctly set'
assert lightning_getattr(model7, 'batch_size') == 128, \
'batch_size not correctly set'