Unify attribute finding logic, fix not using dataloader when hparams present (#4559)
* Rebase onto master
* indent fix
* Remove duplicated logic
* Use single return
* Remove extra else
* add `__contains__` to TestHparamsNamespace to fix tests
* Fix lightning_setattr to set all valid attributes
* update doc
* better names
* fix holder order preference
* tests for new behavior
* Comment about using the last holder
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
(cherry picked from commit eee3b1a284
)
This commit is contained in:
parent
cc38d4c342
commit
da5ba50727
|
@ -196,50 +196,56 @@ class AttributeDict(Dict):
|
|||
return out
|
||||
|
||||
|
||||
def lightning_get_all_attr_holders(model, attribute):
|
||||
""" Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
|
||||
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
|
||||
trainer = getattr(model, 'trainer', None)
|
||||
|
||||
holders = []
|
||||
|
||||
# Check if attribute in model
|
||||
if hasattr(model, attribute):
|
||||
holders.append(model)
|
||||
|
||||
# Check if attribute in model.hparams, either namespace or dict
|
||||
if hasattr(model, 'hparams'):
|
||||
if attribute in model.hparams:
|
||||
holders.append(model.hparams)
|
||||
|
||||
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
|
||||
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
|
||||
holders.append(trainer.datamodule)
|
||||
|
||||
return holders
|
||||
|
||||
|
||||
def lightning_get_first_attr_holder(model, attribute):
|
||||
""" Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace,
|
||||
the old hparams namespace/dict, and the datamodule, returns the last one that has it. """
|
||||
holders = lightning_get_all_attr_holders(model, attribute)
|
||||
if len(holders) == 0:
|
||||
return None
|
||||
# using the last holder to preserve backwards compatibility
|
||||
return holders[-1]
|
||||
|
||||
|
||||
def lightning_hasattr(model, attribute):
|
||||
""" Special hasattr for lightning. Checks for attribute in model namespace,
|
||||
the old hparams namespace/dict, and the datamodule. """
|
||||
trainer = getattr(model, 'trainer', None)
|
||||
|
||||
attr = False
|
||||
# Check if attribute in model
|
||||
if hasattr(model, attribute):
|
||||
attr = True
|
||||
# Check if attribute in model.hparams, either namespace or dict
|
||||
elif hasattr(model, 'hparams'):
|
||||
if isinstance(model.hparams, dict):
|
||||
attr = attribute in model.hparams
|
||||
else:
|
||||
attr = hasattr(model.hparams, attribute)
|
||||
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
|
||||
if not attr and trainer is not None:
|
||||
attr = hasattr(trainer.datamodule, attribute)
|
||||
|
||||
return attr
|
||||
return lightning_get_first_attr_holder(model, attribute) is not None
|
||||
|
||||
|
||||
def lightning_getattr(model, attribute):
|
||||
""" Special getattr for lightning. Checks for attribute in model namespace,
|
||||
the old hparams namespace/dict, and the datamodule. """
|
||||
trainer = getattr(model, 'trainer', None)
|
||||
holder = lightning_get_first_attr_holder(model, attribute)
|
||||
if holder is None:
|
||||
raise ValueError(f'{attribute} is neither stored in the model namespace'
|
||||
' nor the `hparams` namespace/dict, nor the datamodule.')
|
||||
|
||||
# Check if attribute in model
|
||||
if hasattr(model, attribute):
|
||||
attr = getattr(model, attribute)
|
||||
# Check if attribute in model.hparams, either namespace or dict
|
||||
elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams:
|
||||
attr = model.hparams[attribute]
|
||||
elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute):
|
||||
attr = getattr(model.hparams, attribute)
|
||||
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
|
||||
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
|
||||
attr = getattr(trainer.datamodule, attribute)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
|
||||
' nor the datamodule.'
|
||||
)
|
||||
return attr
|
||||
if isinstance(holder, dict):
|
||||
return holder[attribute]
|
||||
return getattr(holder, attribute)
|
||||
|
||||
|
||||
def lightning_setattr(model, attribute, value):
|
||||
|
@ -247,25 +253,13 @@ def lightning_setattr(model, attribute, value):
|
|||
and the old hparams namespace/dict.
|
||||
Will also set the attribute on datamodule, if it exists.
|
||||
"""
|
||||
if not lightning_hasattr(model, attribute):
|
||||
raise ValueError(
|
||||
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
|
||||
' nor the datamodule.'
|
||||
)
|
||||
holders = lightning_get_all_attr_holders(model, attribute)
|
||||
if len(holders) == 0:
|
||||
raise ValueError(f'{attribute} is neither stored in the model namespace'
|
||||
' nor the `hparams` namespace/dict, nor the datamodule.')
|
||||
|
||||
trainer = getattr(model, 'trainer', None)
|
||||
|
||||
# Check if attribute in model
|
||||
if hasattr(model, attribute):
|
||||
setattr(model, attribute, value)
|
||||
|
||||
# Check if attribute in model.hparams, either namespace or dict
|
||||
elif hasattr(model, 'hparams'):
|
||||
if isinstance(model.hparams, dict):
|
||||
model.hparams[attribute] = value
|
||||
for holder in holders:
|
||||
if isinstance(holder, dict):
|
||||
holder[attribute] = value
|
||||
else:
|
||||
setattr(model.hparams, attribute, value)
|
||||
|
||||
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
|
||||
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
|
||||
setattr(trainer.datamodule, attribute, value)
|
||||
setattr(holder, attribute, value)
|
||||
|
|
|
@ -19,6 +19,9 @@ def _get_test_cases():
|
|||
class TestHparamsNamespace:
|
||||
learning_rate = 1
|
||||
|
||||
def __contains__(self, item):
|
||||
return item == "learning_rate"
|
||||
|
||||
TestHparamsDict = {'learning_rate': 2}
|
||||
|
||||
class TestModel1: # test for namespace
|
||||
|
@ -52,12 +55,26 @@ def _get_test_cases():
|
|||
|
||||
model5 = TestModel5()
|
||||
|
||||
return model1, model2, model3, model4, model5
|
||||
class TestModel6: # test for datamodule w/ hparams w/o attribute (should use datamodule)
|
||||
trainer = Trainer
|
||||
hparams = TestHparamsDict
|
||||
|
||||
model6 = TestModel6()
|
||||
|
||||
TestHparamsDict2 = {'batch_size': 2}
|
||||
|
||||
class TestModel7: # test for datamodule w/ hparams w/ attribute (should use datamodule)
|
||||
trainer = Trainer
|
||||
hparams = TestHparamsDict2
|
||||
|
||||
model7 = TestModel7()
|
||||
|
||||
return model1, model2, model3, model4, model5, model6, model7
|
||||
|
||||
|
||||
def test_lightning_hasattr(tmpdir):
|
||||
""" Test that the lightning_hasattr works in all cases"""
|
||||
model1, model2, model3, model4, model5 = _get_test_cases()
|
||||
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
|
||||
assert lightning_hasattr(model1, 'learning_rate'), \
|
||||
'lightning_hasattr failed to find namespace variable'
|
||||
assert lightning_hasattr(model2, 'learning_rate'), \
|
||||
|
@ -68,6 +85,10 @@ def test_lightning_hasattr(tmpdir):
|
|||
'lightning_hasattr found variable when it should not'
|
||||
assert lightning_hasattr(model5, 'batch_size'), \
|
||||
'lightning_hasattr failed to find batch_size in datamodule'
|
||||
assert lightning_hasattr(model6, 'batch_size'), \
|
||||
'lightning_hasattr failed to find batch_size in datamodule w/ hparams present'
|
||||
assert lightning_hasattr(model7, 'batch_size'), \
|
||||
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'
|
||||
|
||||
|
||||
def test_lightning_getattr(tmpdir):
|
||||
|
@ -77,9 +98,13 @@ def test_lightning_getattr(tmpdir):
|
|||
value = lightning_getattr(m, 'learning_rate')
|
||||
assert value == i, 'attribute not correctly extracted'
|
||||
|
||||
model5 = models[4]
|
||||
model5, model6, model7 = models[4:]
|
||||
assert lightning_getattr(model5, 'batch_size') == 8, \
|
||||
'batch_size not correctly extracted'
|
||||
assert lightning_getattr(model6, 'batch_size') == 8, \
|
||||
'batch_size not correctly extracted'
|
||||
assert lightning_getattr(model7, 'batch_size') == 8, \
|
||||
'batch_size not correctly extracted'
|
||||
|
||||
|
||||
def test_lightning_setattr(tmpdir):
|
||||
|
@ -90,7 +115,13 @@ def test_lightning_setattr(tmpdir):
|
|||
assert lightning_getattr(m, 'learning_rate') == 10, \
|
||||
'attribute not correctly set'
|
||||
|
||||
model5 = models[4]
|
||||
model5, model6, model7 = models[4:]
|
||||
lightning_setattr(model5, 'batch_size', 128)
|
||||
lightning_setattr(model6, 'batch_size', 128)
|
||||
lightning_setattr(model7, 'batch_size', 128)
|
||||
assert lightning_getattr(model5, 'batch_size') == 128, \
|
||||
'batch_size not correctly set'
|
||||
assert lightning_getattr(model6, 'batch_size') == 128, \
|
||||
'batch_size not correctly set'
|
||||
assert lightning_getattr(model7, 'batch_size') == 128, \
|
||||
'batch_size not correctly set'
|
||||
|
|
Loading…
Reference in New Issue