Raise AttributeError in lightning_getattr and lightning_setattr when attribute not found (#6024)
* Empty commit * Raise AttributeError instead of ValueError * Make functions private * Update tests * Add match string * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * lightning to Lightning Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
b0074a471a
commit
8f82823a08
|
@ -196,9 +196,11 @@ class AttributeDict(Dict):
|
|||
return out
|
||||
|
||||
|
||||
def lightning_get_all_attr_holders(model, attribute):
|
||||
""" Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
|
||||
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
|
||||
def _lightning_get_all_attr_holders(model, attribute):
|
||||
"""
|
||||
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
|
||||
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
|
||||
"""
|
||||
trainer = getattr(model, 'trainer', None)
|
||||
|
||||
holders = []
|
||||
|
@ -219,13 +221,13 @@ def lightning_get_all_attr_holders(model, attribute):
|
|||
return holders
|
||||
|
||||
|
||||
def lightning_get_first_attr_holder(model, attribute):
|
||||
def _lightning_get_first_attr_holder(model, attribute):
|
||||
"""
|
||||
Special attribute finding for lightning. Gets the object or dict that holds attribute, or None.
|
||||
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
|
||||
returns the last one that has it.
|
||||
"""
|
||||
holders = lightning_get_all_attr_holders(model, attribute)
|
||||
Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.
|
||||
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
|
||||
returns the last one that has it.
|
||||
"""
|
||||
holders = _lightning_get_all_attr_holders(model, attribute)
|
||||
if len(holders) == 0:
|
||||
return None
|
||||
# using the last holder to preserve backwards compatibility
|
||||
|
@ -233,17 +235,26 @@ def lightning_get_first_attr_holder(model, attribute):
|
|||
|
||||
|
||||
def lightning_hasattr(model, attribute):
|
||||
""" Special hasattr for lightning. Checks for attribute in model namespace,
|
||||
the old hparams namespace/dict, and the datamodule. """
|
||||
return lightning_get_first_attr_holder(model, attribute) is not None
|
||||
"""
|
||||
Special hasattr for Lightning. Checks for attribute in model namespace,
|
||||
the old hparams namespace/dict, and the datamodule.
|
||||
"""
|
||||
return _lightning_get_first_attr_holder(model, attribute) is not None
|
||||
|
||||
|
||||
def lightning_getattr(model, attribute):
|
||||
""" Special getattr for lightning. Checks for attribute in model namespace,
|
||||
the old hparams namespace/dict, and the datamodule. """
|
||||
holder = lightning_get_first_attr_holder(model, attribute)
|
||||
"""
|
||||
Special getattr for Lightning. Checks for attribute in model namespace,
|
||||
the old hparams namespace/dict, and the datamodule.
|
||||
|
||||
Raises:
|
||||
AttributeError:
|
||||
If ``model`` doesn't have ``attribute`` in any of
|
||||
model namespace, the hparams namespace/dict, and the datamodule.
|
||||
"""
|
||||
holder = _lightning_get_first_attr_holder(model, attribute)
|
||||
if holder is None:
|
||||
raise ValueError(
|
||||
raise AttributeError(
|
||||
f'{attribute} is neither stored in the model namespace'
|
||||
' nor the `hparams` namespace/dict, nor the datamodule.'
|
||||
)
|
||||
|
@ -254,13 +265,19 @@ def lightning_getattr(model, attribute):
|
|||
|
||||
|
||||
def lightning_setattr(model, attribute, value):
|
||||
""" Special setattr for lightning. Checks for attribute in model namespace
|
||||
and the old hparams namespace/dict.
|
||||
Will also set the attribute on datamodule, if it exists.
|
||||
"""
|
||||
holders = lightning_get_all_attr_holders(model, attribute)
|
||||
Special setattr for Lightning. Checks for attribute in model namespace
|
||||
and the old hparams namespace/dict.
|
||||
Will also set the attribute on datamodule, if it exists.
|
||||
|
||||
Raises:
|
||||
AttributeError:
|
||||
If ``model`` doesn't have ``attribute`` in any of
|
||||
model namespace, the hparams namespace/dict, and the datamodule.
|
||||
"""
|
||||
holders = _lightning_get_all_attr_holders(model, attribute)
|
||||
if len(holders) == 0:
|
||||
raise ValueError(
|
||||
raise AttributeError(
|
||||
f'{attribute} is neither stored in the model namespace'
|
||||
' nor the `hparams` namespace/dict, nor the datamodule.'
|
||||
)
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
|
||||
|
||||
|
@ -74,8 +75,8 @@ def _get_test_cases():
|
|||
|
||||
|
||||
def test_lightning_hasattr(tmpdir):
|
||||
""" Test that the lightning_hasattr works in all cases"""
|
||||
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
|
||||
"""Test that the lightning_hasattr works in all cases"""
|
||||
model1, model2, model3, model4, model5, model6, model7 = models = _get_test_cases()
|
||||
assert lightning_hasattr(model1, 'learning_rate'), \
|
||||
'lightning_hasattr failed to find namespace variable'
|
||||
assert lightning_hasattr(model2, 'learning_rate'), \
|
||||
|
@ -91,9 +92,12 @@ def test_lightning_hasattr(tmpdir):
|
|||
assert lightning_hasattr(model7, 'batch_size'), \
|
||||
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'
|
||||
|
||||
for m in models:
|
||||
assert not lightning_hasattr(m, "this_attr_not_exist")
|
||||
|
||||
|
||||
def test_lightning_getattr(tmpdir):
|
||||
""" Test that the lightning_getattr works in all cases"""
|
||||
"""Test that the lightning_getattr works in all cases"""
|
||||
models = _get_test_cases()
|
||||
for i, m in enumerate(models[:3]):
|
||||
value = lightning_getattr(m, 'learning_rate')
|
||||
|
@ -107,9 +111,16 @@ def test_lightning_getattr(tmpdir):
|
|||
assert lightning_getattr(model7, 'batch_size') == 8, \
|
||||
'batch_size not correctly extracted'
|
||||
|
||||
for m in models:
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
|
||||
):
|
||||
lightning_getattr(m, "this_attr_not_exist")
|
||||
|
||||
|
||||
def test_lightning_setattr(tmpdir):
|
||||
""" Test that the lightning_setattr works in all cases"""
|
||||
"""Test that the lightning_setattr works in all cases"""
|
||||
models = _get_test_cases()
|
||||
for m in models[:3]:
|
||||
lightning_setattr(m, 'learning_rate', 10)
|
||||
|
@ -126,3 +137,10 @@ def test_lightning_setattr(tmpdir):
|
|||
'batch_size not correctly set'
|
||||
assert lightning_getattr(model7, 'batch_size') == 128, \
|
||||
'batch_size not correctly set'
|
||||
|
||||
for m in models:
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
|
||||
):
|
||||
lightning_setattr(m, "this_attr_not_exist", None)
|
||||
|
|
Loading…
Reference in New Issue