From da5ba50727c6af80b65570ad77e8be59b255abcb Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 26 Jan 2021 03:57:17 +0100 Subject: [PATCH] Unify attribute finding logic, fix not using dataloader when hparams present (#4559) * Rebase onto master * indent fix * Remove duplicated logic * Use single return * Remove extra else * add `__contains__` to TestHparamsNamespace to fix tests * Fix lightning_setattr to set all valid attributes * update doc * better names * fix holder order preference * tests for new behavior * Comment about using the last holder Co-authored-by: Jirka Borovec Co-authored-by: Sean Naren (cherry picked from commit eee3b1a284470b3eca6ebd1815c0ddd7f550c667) --- pytorch_lightning/utilities/parsing.py | 104 ++++++++++++------------- tests/utilities/test_parsing.py | 39 +++++++++- 2 files changed, 84 insertions(+), 59 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 57568967ae..e55a96a55b 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -196,50 +196,56 @@ class AttributeDict(Dict): return out +def lightning_get_all_attr_holders(model, attribute): + """ Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute. + Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ + trainer = getattr(model, 'trainer', None) + + holders = [] + + # Check if attribute in model + if hasattr(model, attribute): + holders.append(model) + + # Check if attribute in model.hparams, either namespace or dict + if hasattr(model, 'hparams'): + if attribute in model.hparams: + holders.append(model.hparams) + + # Check if the attribute in datamodule (datamodule gets registered in Trainer) + if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): + holders.append(trainer.datamodule) + + return holders + + +def lightning_get_first_attr_holder(model, attribute): + """ Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, + the old hparams namespace/dict, and the datamodule, returns the last one that has it. """ + holders = lightning_get_all_attr_holders(model, attribute) + if len(holders) == 0: + return None + # using the last holder to preserve backwards compatibility + return holders[-1] + + def lightning_hasattr(model, attribute): """ Special hasattr for lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ - trainer = getattr(model, 'trainer', None) - - attr = False - # Check if attribute in model - if hasattr(model, attribute): - attr = True - # Check if attribute in model.hparams, either namespace or dict - elif hasattr(model, 'hparams'): - if isinstance(model.hparams, dict): - attr = attribute in model.hparams - else: - attr = hasattr(model.hparams, attribute) - # Check if the attribute in datamodule (datamodule gets registered in Trainer) - if not attr and trainer is not None: - attr = hasattr(trainer.datamodule, attribute) - - return attr + return lightning_get_first_attr_holder(model, attribute) is not None def lightning_getattr(model, attribute): """ Special getattr for lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ - trainer = getattr(model, 'trainer', None) + holder = lightning_get_first_attr_holder(model, attribute) + if holder is None: + raise ValueError(f'{attribute} is neither stored in the model namespace' + ' nor the `hparams` namespace/dict, nor the datamodule.') - # Check if attribute in model - if hasattr(model, attribute): - attr = getattr(model, attribute) - # Check if attribute in model.hparams, either namespace or dict - elif hasattr(model, 'hparams') and isinstance(model.hparams, dict) and attribute in model.hparams: - attr = model.hparams[attribute] - elif hasattr(model, 'hparams') and hasattr(model.hparams, attribute): - attr = getattr(model.hparams, attribute) - # Check if the attribute in datamodule (datamodule gets registered in Trainer) - elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): - attr = getattr(trainer.datamodule, attribute) - else: - raise ValueError( - f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,' - ' nor the datamodule.' - ) - return attr + if isinstance(holder, dict): + return holder[attribute] + return getattr(holder, attribute) def lightning_setattr(model, attribute, value): @@ -247,25 +253,13 @@ def lightning_setattr(model, attribute, value): and the old hparams namespace/dict. Will also set the attribute on datamodule, if it exists. """ - if not lightning_hasattr(model, attribute): - raise ValueError( - f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,' - ' nor the datamodule.' - ) + holders = lightning_get_all_attr_holders(model, attribute) + if len(holders) == 0: + raise ValueError(f'{attribute} is neither stored in the model namespace' + ' nor the `hparams` namespace/dict, nor the datamodule.') - trainer = getattr(model, 'trainer', None) - - # Check if attribute in model - if hasattr(model, attribute): - setattr(model, attribute, value) - - # Check if attribute in model.hparams, either namespace or dict - elif hasattr(model, 'hparams'): - if isinstance(model.hparams, dict): - model.hparams[attribute] = value + for holder in holders: + if isinstance(holder, dict): + holder[attribute] = value else: - setattr(model.hparams, attribute, value) - - # Check if the attribute in datamodule (datamodule gets registered in Trainer) - if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): - setattr(trainer.datamodule, attribute, value) + setattr(holder, attribute, value) diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py index ba8f20d778..08e24d746f 100644 --- a/tests/utilities/test_parsing.py +++ b/tests/utilities/test_parsing.py @@ -19,6 +19,9 @@ def _get_test_cases(): class TestHparamsNamespace: learning_rate = 1 + def __contains__(self, item): + return item == "learning_rate" + TestHparamsDict = {'learning_rate': 2} class TestModel1: # test for namespace @@ -52,12 +55,26 @@ def _get_test_cases(): model5 = TestModel5() - return model1, model2, model3, model4, model5 + class TestModel6: # test for datamodule w/ hparams w/o attribute (should use datamodule) + trainer = Trainer + hparams = TestHparamsDict + + model6 = TestModel6() + + TestHparamsDict2 = {'batch_size': 2} + + class TestModel7: # test for datamodule w/ hparams w/ attribute (should use datamodule) + trainer = Trainer + hparams = TestHparamsDict2 + + model7 = TestModel7() + + return model1, model2, model3, model4, model5, model6, model7 def test_lightning_hasattr(tmpdir): """ Test that the lightning_hasattr works in all cases""" - model1, model2, model3, model4, model5 = _get_test_cases() + model1, model2, model3, model4, model5, model6, model7 = _get_test_cases() assert lightning_hasattr(model1, 'learning_rate'), \ 'lightning_hasattr failed to find namespace variable' assert lightning_hasattr(model2, 'learning_rate'), \ @@ -68,6 +85,10 @@ def test_lightning_hasattr(tmpdir): 'lightning_hasattr found variable when it should not' assert lightning_hasattr(model5, 'batch_size'), \ 'lightning_hasattr failed to find batch_size in datamodule' + assert lightning_hasattr(model6, 'batch_size'), \ + 'lightning_hasattr failed to find batch_size in datamodule w/ hparams present' + assert lightning_hasattr(model7, 'batch_size'), \ + 'lightning_hasattr failed to find batch_size in hparams w/ datamodule present' def test_lightning_getattr(tmpdir): @@ -77,9 +98,13 @@ def test_lightning_getattr(tmpdir): value = lightning_getattr(m, 'learning_rate') assert value == i, 'attribute not correctly extracted' - model5 = models[4] + model5, model6, model7 = models[4:] assert lightning_getattr(model5, 'batch_size') == 8, \ 'batch_size not correctly extracted' + assert lightning_getattr(model6, 'batch_size') == 8, \ + 'batch_size not correctly extracted' + assert lightning_getattr(model7, 'batch_size') == 8, \ + 'batch_size not correctly extracted' def test_lightning_setattr(tmpdir): @@ -90,7 +115,13 @@ def test_lightning_setattr(tmpdir): assert lightning_getattr(m, 'learning_rate') == 10, \ 'attribute not correctly set' - model5 = models[4] + model5, model6, model7 = models[4:] lightning_setattr(model5, 'batch_size', 128) + lightning_setattr(model6, 'batch_size', 128) + lightning_setattr(model7, 'batch_size', 128) assert lightning_getattr(model5, 'batch_size') == 128, \ 'batch_size not correctly set' + assert lightning_getattr(model6, 'batch_size') == 128, \ + 'batch_size not correctly set' + assert lightning_getattr(model7, 'batch_size') == 128, \ + 'batch_size not correctly set'