Raise AttributeError in lightning_getattr and lightning_setattr when attribute not found (#6024)

* Empty commit

* Raise AttributeError instead of ValueError

* Make functions private

* Update tests

* Add match string

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* lightning to Lightning

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Akihiro Nitta 2021-02-19 05:01:28 +09:00 committed by GitHub
parent b0074a471a
commit 8f82823a08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 25 deletions

View File

@@ -196,9 +196,11 @@ class AttributeDict(Dict):
return out
def lightning_get_all_attr_holders(model, attribute):
""" Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
def _lightning_get_all_attr_holders(model, attribute):
"""
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
"""
trainer = getattr(model, 'trainer', None)
holders = []
@@ -219,13 +221,13 @@ def lightning_get_all_attr_holders(model, attribute):
return holders
def lightning_get_first_attr_holder(model, attribute):
def _lightning_get_first_attr_holder(model, attribute):
"""
Special attribute finding for lightning. Gets the object or dict that holds attribute, or None.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
returns the last one that has it.
"""
holders = lightning_get_all_attr_holders(model, attribute)
Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
returns the last one that has it.
"""
holders = _lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
return None
# using the last holder to preserve backwards compatibility
@@ -233,17 +235,26 @@ def lightning_get_first_attr_holder(model, attribute):
def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
return lightning_get_first_attr_holder(model, attribute) is not None
"""
Special hasattr for Lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule.
"""
return _lightning_get_first_attr_holder(model, attribute) is not None
def lightning_getattr(model, attribute):
""" Special getattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
holder = lightning_get_first_attr_holder(model, attribute)
"""
Special getattr for Lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule.
Raises:
AttributeError:
If ``model`` doesn't have ``attribute`` in any of
model namespace, the hparams namespace/dict, and the datamodule.
"""
holder = _lightning_get_first_attr_holder(model, attribute)
if holder is None:
raise ValueError(
raise AttributeError(
f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.'
)
@@ -254,13 +265,19 @@ def lightning_getattr(model, attribute):
def lightning_setattr(model, attribute, value):
""" Special setattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict.
Will also set the attribute on datamodule, if it exists.
"""
holders = lightning_get_all_attr_holders(model, attribute)
Special setattr for Lightning. Checks for attribute in model namespace
and the old hparams namespace/dict.
Will also set the attribute on datamodule, if it exists.
Raises:
AttributeError:
If ``model`` doesn't have ``attribute`` in any of
model namespace, the hparams namespace/dict, and the datamodule.
"""
holders = _lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
raise ValueError(
raise AttributeError(
f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.'
)

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
@@ -74,8 +75,8 @@ def _get_test_cases():
def test_lightning_hasattr(tmpdir):
""" Test that the lightning_hasattr works in all cases"""
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
"""Test that the lightning_hasattr works in all cases"""
model1, model2, model3, model4, model5, model6, model7 = models = _get_test_cases()
assert lightning_hasattr(model1, 'learning_rate'), \
'lightning_hasattr failed to find namespace variable'
assert lightning_hasattr(model2, 'learning_rate'), \
@@ -91,9 +92,12 @@ def test_lightning_hasattr(tmpdir):
assert lightning_hasattr(model7, 'batch_size'), \
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'
for m in models:
assert not lightning_hasattr(m, "this_attr_not_exist")
def test_lightning_getattr(tmpdir):
""" Test that the lightning_getattr works in all cases"""
"""Test that the lightning_getattr works in all cases"""
models = _get_test_cases()
for i, m in enumerate(models[:3]):
value = lightning_getattr(m, 'learning_rate')
@@ -107,9 +111,16 @@ def test_lightning_getattr(tmpdir):
assert lightning_getattr(model7, 'batch_size') == 8, \
'batch_size not correctly extracted'
for m in models:
with pytest.raises(
AttributeError,
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
):
lightning_getattr(m, "this_attr_not_exist")
def test_lightning_setattr(tmpdir):
""" Test that the lightning_setattr works in all cases"""
"""Test that the lightning_setattr works in all cases"""
models = _get_test_cases()
for m in models[:3]:
lightning_setattr(m, 'learning_rate', 10)
@@ -126,3 +137,10 @@ def test_lightning_setattr(tmpdir):
'batch_size not correctly set'
assert lightning_getattr(model7, 'batch_size') == 128, \
'batch_size not correctly set'
for m in models:
with pytest.raises(
AttributeError,
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
):
lightning_setattr(m, "this_attr_not_exist", None)