lightning/tests/utilities/test_parsing.py

297 lines
9.4 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import pytest
from torch.jit import ScriptModule
from pytorch_lightning.utilities.parsing import (
AttributeDict,
clean_namespace,
collect_init_args,
flatten_dict,
get_init_args,
is_picklable,
lightning_getattr,
lightning_hasattr,
lightning_setattr,
parse_class_init_keys,
str_to_bool,
str_to_bool_or_int,
str_to_bool_or_str,
)
# Known-unpicklable value shared by ``test_is_picklable`` and ``test_clean_namespace``.
# It must stay a lambda: pickle stores functions by qualified name, and ``<lambda>`` cannot be
# looked up on this module. A plain module-level ``def`` would be picklable by reference.
unpicklable_function = lambda: None
@pytest.fixture
def model_cases():
    """Build fresh toy "models" covering every lookup path of the ``lightning_*attr`` helpers.

    The fixture deliberately uses pytest's default function scope.
    ``test_lightning_setattr`` mutates these objects, including class attributes of the
    locally defined ``Trainer``/``DataModule`` that models 5-7 share. With
    ``scope="module"`` those mutations leaked into later tests of the module. As a result,
    ``test_lightning_getattr`` failed whenever it ran after ``test_lightning_setattr``
    (test selection, randomized ordering).

    Returns:
        A 7-tuple ``(model1, ..., model7)``:

        1. ``learning_rate`` (0) stored directly on the model.
        2. ``learning_rate`` (1) stored on an ``hparams`` namespace object.
        3. ``learning_rate`` (2) stored in an ``hparams`` dict.
        4. Fail case: no ``learning_rate``, no ``hparams`` and no trainer.
        5. ``batch_size`` (8) reachable only through ``trainer.datamodule``.
        6. ``batch_size`` through the datamodule while an ``hparams`` dict lacking it is present.
        7. ``batch_size`` present both in ``hparams`` (2) and in the datamodule (8).
    """

    class TestHparamsNamespace:
        learning_rate = 1

        def __contains__(self, item):
            # Lets ``"learning_rate" in hparams`` succeed, mirroring the dict-based hparams.
            return item == "learning_rate"

    TestHparamsDict = {"learning_rate": 2}

    class TestModel1:  # test for namespace
        learning_rate = 0

    model1 = TestModel1()

    class TestModel2:  # test for hparams namespace
        hparams = TestHparamsNamespace()

    model2 = TestModel2()

    class TestModel3:  # test for hparams dict
        hparams = TestHparamsDict

    model3 = TestModel3()

    class TestModel4:  # fail case
        batch_size = 1

    model4 = TestModel4()

    class DataModule:
        batch_size = 8

    class Trainer:
        datamodule = DataModule

    class TestModel5:  # test for datamodule
        trainer = Trainer

    model5 = TestModel5()

    class TestModel6:  # test for datamodule w/ hparams w/o attribute (should use datamodule)
        trainer = Trainer
        hparams = TestHparamsDict

    model6 = TestModel6()

    TestHparamsDict2 = {"batch_size": 2}

    class TestModel7:  # test for datamodule w/ hparams w/ attribute (should use datamodule)
        trainer = Trainer
        hparams = TestHparamsDict2

    model7 = TestModel7()

    return model1, model2, model3, model4, model5, model6, model7
def test_lightning_hasattr(tmpdir, model_cases):
    """Check ``lightning_hasattr`` against every holder: model, ``hparams`` and datamodule.

    Unknown attribute names must be reported as missing on every model.
    """
    model1, model2, model3, model4, model5, model6, model7 = model_cases
    expectations = (
        (model1, "learning_rate", True, "lightning_hasattr failed to find namespace variable"),
        (model2, "learning_rate", True, "lightning_hasattr failed to find hparams namespace variable"),
        (model3, "learning_rate", True, "lightning_hasattr failed to find hparams dict variable"),
        (model4, "learning_rate", False, "lightning_hasattr found variable when it should not"),
        (model5, "batch_size", True, "lightning_hasattr failed to find batch_size in datamodule"),
        (
            model6,
            "batch_size",
            True,
            "lightning_hasattr failed to find batch_size in datamodule w/ hparams present",
        ),
        (
            model7,
            "batch_size",
            True,
            "lightning_hasattr failed to find batch_size in hparams w/ datamodule present",
        ),
    )
    for model, attr_name, should_exist, failure_msg in expectations:
        assert bool(lightning_hasattr(model, attr_name)) is should_exist, failure_msg

    # No lookup path provides this name, so it must be missing everywhere.
    for model in model_cases:
        assert not lightning_hasattr(model, "this_attr_not_exist")
def test_lightning_getattr(tmpdir, model_cases):
    """Check that ``lightning_getattr`` reads the value from the holder that owns it.

    Unknown names must raise ``AttributeError``.
    """
    # model1..model3 hold learning_rate values 0, 1 and 2 respectively.
    for expected, model in zip(range(3), model_cases[:3]):
        assert lightning_getattr(model, "learning_rate") == expected, "attribute not correctly extracted"

    model5, model6, model7 = model_cases[4:]
    for model in (model5, model6, model7):
        # All three resolve to the datamodule value, even model7 whose hparams also define it.
        assert lightning_getattr(model, "batch_size") == 8, "batch_size not correctly extracted"

    missing_msg = "is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
    for model in model_cases:
        with pytest.raises(AttributeError, match=missing_msg):
            lightning_getattr(model, "this_attr_not_exist")
def test_lightning_setattr(tmpdir, model_cases):
    """Check that ``lightning_setattr`` writes through to the holder that owns the attribute.

    Unknown names must raise ``AttributeError``.
    """
    for model in model_cases[:3]:
        lightning_setattr(model, "learning_rate", 10)
        assert lightning_getattr(model, "learning_rate") == 10, "attribute not correctly set"

    model5, model6, model7 = model_cases[4:]
    datamodule_models = (model5, model6, model7)
    # Write all three first, then read them back.
    for model in datamodule_models:
        lightning_setattr(model, "batch_size", 128)
    for model in datamodule_models:
        assert lightning_getattr(model, "batch_size") == 128, "batch_size not correctly set"

    missing_msg = "is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
    for model in model_cases:
        with pytest.raises(AttributeError, match=missing_msg):
            lightning_setattr(model, "this_attr_not_exist", None)
def test_str_to_bool_or_str():
    """Recognised truthy/falsy spellings become booleans; other strings are returned unchanged."""
    cases = [(word, True) for word in ("y", "yes", "t", "true", "on", "1")]
    cases += [(word, False) for word in ("n", "no", "f", "false", "off", "0")]
    cases += [(word, word) for word in ("yyeess", "noooo", "lightning")]

    for raw, expected in cases:
        result = str_to_bool_or_str(raw)
        if isinstance(expected, bool):
            # Booleans must be the singletons themselves, not merely truthy/falsy values.
            assert result is expected
        else:
            assert result == expected
def test_str_to_bool():
    """Only the recognised spellings are accepted; any other string raises ``ValueError``."""
    spellings = {
        True: ("y", "yes", "t", "true", "on", "1"),
        False: ("n", "no", "f", "false", "off", "0"),
    }
    for expected, words in spellings.items():
        for word in words:
            assert str_to_bool(word) is expected

    for word in ("yyeess", "noooo", "lightning"):
        with pytest.raises(ValueError):
            str_to_bool(word)
def test_str_to_bool_or_int():
    """Bool spellings win over ints ("0"/"1"), other ints are parsed, the rest pass through."""
    for raw, expected in (("0", False), ("1", True), ("true", True)):
        assert str_to_bool_or_int(raw) is expected
    for raw, expected in (("2", 2), ("abc", "abc")):
        assert str_to_bool_or_int(raw) == expected
def test_is_picklable(tmpdir):
    """Check which objects ``is_picklable`` accepts and which it rejects.

    Builtin values, tuples and top-level functions are accepted. A lambda, a locally
    defined class and a ``ScriptModule`` are rejected.
    """
    # See the full list of picklable types at
    # https://docs.python.org/3/library/pickle.html#pickle-picklable

    class UnpicklableClass:
        # Only classes defined at the top level of a module are picklable.
        pass

    expectations = [(obj, True) for obj in (None, True, 123, "str", (123, "str"), max)]
    expectations += [(obj, False) for obj in (unpicklable_function, UnpicklableClass, ScriptModule())]
    for obj, expected in expectations:
        assert is_picklable(obj) is expected
def test_clean_namespace(tmpdir):
    """``clean_namespace`` removes unpicklable entries in place and keeps the picklable ones."""
    # See the full list of picklable types at
    # https://docs.python.org/3/library/pickle.html#pickle-picklable

    class UnpicklableClass:
        # Only classes defined at the top level of a module are picklable.
        pass

    picklable_part = {"1": None, "2": True, "3": 123}
    # A separate dict, so the in-place cleaning cannot touch ``picklable_part``.
    namespace = {**picklable_part, "4": unpicklable_function, "5": UnpicklableClass}
    clean_namespace(namespace)
    assert namespace == picklable_part
def test_parse_class_init_keys(tmpdir):
    """Check the names ``parse_class_init_keys`` reports for a class's ``__init__``.

    It must return the names of ``self``, ``*args`` and ``**kwargs``, in that order.
    """

    class Dummy:
        def __init__(self, hparams, *my_args, anykw=42, **my_kwargs):
            pass

    expected_keys = ("self", "my_args", "my_kwargs")
    assert parse_class_init_keys(Dummy) == expected_keys
def test_get_init_args(tmpdir):
    """Check ``get_init_args`` on an ``__init__`` frame and on an ordinary frame.

    For a frame of ``__init__`` it returns the constructor arguments, with ``**kwargs``
    merged in. For an ordinary frame it returns an empty dict.
    """

    class AutomaticArgsModel:
        def __init__(self, anyarg, anykw=42, **kwargs):
            # Zero-argument ``super()`` gives this frame an implicit ``__class__`` cell.
            # The helper presumably relies on it, so keep this body unchanged.
            super().__init__()
            self.get_init_args_wrapper()

        def get_init_args_wrapper(self):
            # Inspect the *caller's* frame: ``__init__`` during construction.
            self.result = get_init_args(inspect.currentframe().f_back)

    model = AutomaticArgsModel("test", anykw=32, otherkw=123)
    assert model.result == {"anyarg": "test", "anykw": 32, "otherkw": 123}

    # Called from this test function, the caller frame is not an ``__init__``.
    model.get_init_args_wrapper()
    assert model.result == {}
def test_collect_init_args():
    """Check that ``collect_init_args`` gathers arguments from every ``__init__`` up the stack.

    Results are ordered innermost first: the parent's arguments, then the child's.
    """

    class AutomaticArgsParent:
        def __init__(self, anyarg, anykw=42, **kwargs):
            super().__init__()
            self.get_init_args_wrapper()

        def get_init_args_wrapper(self):
            # Start from this non-``__init__`` frame; the helper is expected to climb from here.
            # NOTE(review): avoid ``super()`` in this method so the frame stays a plain method frame.
            self.result = collect_init_args(inspect.currentframe(), [])

    class AutomaticArgsChild(AutomaticArgsParent):
        def __init__(self, anyarg, childarg, anykw=42, childkw=42, **kwargs):
            super().__init__(anyarg, anykw=anykw, **kwargs)

    child = AutomaticArgsChild("test1", "test2", anykw=32, childkw=22, otherkw=123)
    collected = child.result
    assert collected[0] == {"anyarg": "test1", "anykw": 32, "otherkw": 123}
    assert collected[1] == {"anyarg": "test1", "childarg": "test2", "anykw": 32, "childkw": 22, "otherkw": 123}
def test_attribute_dict(tmpdir):
    """Check that ``AttributeDict`` keys are readable as attributes.

    Covers keys given at construction, keys added via ``update``, and attribute assignment.
    """
    # Keys passed at construction are readable as attributes.
    initial = {"key1": 1, "key2": "abc"}
    from_mapping = AttributeDict(initial)
    for name in initial:
        assert getattr(from_mapping, name) == initial[name]

    # Keys added later through ``update`` are readable as attributes too.
    updated = AttributeDict()
    updated.update({"key1": 1})
    assert updated.key1 == 1

    # Assigning through attribute syntax replaces the stored value.
    assigned = AttributeDict({"key1": 1})
    assigned.key1 = 123
    assert assigned.key1 == 123
def test_flatten_dict(tmpdir):
    """Nested dicts collapse into one level; in this case only the leaf keys remain."""
    nested = {"1": 1, "_": {"2": 2, "_": {"3": 3, "4": 4}}}
    flat = flatten_dict(nested)
    assert flat == {"1": 1, "2": 2, "3": 3, "4": 4}