# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect

import pytest
from torch.jit import ScriptModule

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.utilities.parsing import (
    AttributeDict,
    clean_namespace,
    collect_init_args,
    flatten_dict,
    get_init_args,
    is_picklable,
    lightning_getattr,
    lightning_hasattr,
    lightning_setattr,
    parse_class_init_keys,
    str_to_bool,
    str_to_bool_or_int,
    str_to_bool_or_str,
)

# Deliberately a lambda: the picklability tests below use it as an object that cannot be pickled.
unpicklable_function = lambda: None  # noqa: E731


def model_and_trainer_cases():
    """Build the models and trainers shared by the ``lightning_*attr`` tests.

    Returns:
        A pair ``(models, trainers)`` where ``models`` is the tuple
        ``(model1, ..., model8)`` and ``trainers`` is ``(trainer1, trainer2)``.
    """

    class TestHparamsNamespace(LightningModule):
        learning_rate = 1

        def __contains__(self, item):
            return item == "learning_rate"

    TestHparamsDict = {"learning_rate": 2}

    class TestModel1(LightningModule):  # test for namespace
        learning_rate = 0

    model1 = TestModel1()

    class TestModel2(LightningModule):  # test for hparams namespace
        hparams = TestHparamsNamespace()

    model2 = TestModel2()

    class TestModel3(LightningModule):  # test for hparams dict
        hparams = TestHparamsDict

    model3 = TestModel3()

    class TestModel4(LightningModule):  # fail case
        batch_size = 1

    model4 = TestModel4()

    # trainer1 carries a plain datamodule that exposes `batch_size` as an attribute.
    trainer1 = Trainer()
    model4.trainer = trainer1
    datamodule = LightningDataModule()
    datamodule.batch_size = 8
    trainer1.datamodule = datamodule

    model5 = LightningModule()
    model5.trainer = trainer1

    class TestModel6(LightningModule):  # test for datamodule w/ hparams w/o attribute (should use datamodule)
        hparams = TestHparamsDict

    model6 = TestModel6()
    model6.trainer = trainer1

    TestHparamsDict2 = {"batch_size": 2}

    class TestModel7(LightningModule):  # test for datamodule w/ hparams w/ attribute (should use datamodule)
        hparams = TestHparamsDict2

    model7 = TestModel7()
    model7.trainer = trainer1

    class TestDataModule8(LightningDataModule):  # test for hparams dict
        # Use a separate copy so that writes made through model7 (which holds
        # `TestHparamsDict2` itself) cannot leak into this datamodule.
        hparams = dict(TestHparamsDict2)

    # trainer2 carries a datamodule whose only `batch_size` lives in its hparams dict.
    model8 = TestModel1()
    trainer2 = Trainer()
    model8.trainer = trainer2
    datamodule = TestDataModule8()
    trainer2.datamodule = datamodule

    return (model1, model2, model3, model4, model5, model6, model7, model8), (trainer1, trainer2)


def test_lightning_hasattr():
    """Test that the lightning_hasattr works in all cases."""
    models, _ = model_and_trainer_cases()
    model1, model2, model3, model4, model5, model6, model7, model8 = models
    assert lightning_hasattr(model1, "learning_rate"), "lightning_hasattr failed to find namespace variable"
    assert lightning_hasattr(model2, "learning_rate"), "lightning_hasattr failed to find hparams namespace variable"
    assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable"
    assert not lightning_hasattr(model4, "learning_rate"), "lightning_hasattr found variable when it should not"
    assert lightning_hasattr(model5, "batch_size"), "lightning_hasattr failed to find batch_size in datamodule"
    assert lightning_hasattr(
        model6, "batch_size"
    ), "lightning_hasattr failed to find batch_size in datamodule w/ hparams present"
    assert lightning_hasattr(
        model7, "batch_size"
    ), "lightning_hasattr failed to find batch_size in hparams w/ datamodule present"
    assert lightning_hasattr(model8, "batch_size"), "lightning_hasattr failed to find batch_size in datamodule hparams"

    for m in models:
        assert not lightning_hasattr(m, "this_attr_not_exist")


def test_lightning_getattr():
    """Test that the lightning_getattr works in all cases."""
    models, _ = model_and_trainer_cases()
    *__, model5, model6, model7, model8 = models
    # The first three models store learning_rate values 0, 1 and 2 respectively.
    for i, m in enumerate(models[:3]):
        value = lightning_getattr(m, "learning_rate")
        assert value == i, "attribute not correctly extracted"

    assert lightning_getattr(model5, "batch_size") == 8, "batch_size not correctly extracted"
    assert lightning_getattr(model6, "batch_size") == 8, "batch_size not correctly extracted"
    assert lightning_getattr(model7, "batch_size") == 8, "batch_size not correctly extracted"
    assert lightning_getattr(model8, "batch_size") == 2, "batch_size not correctly extracted"

    for m in models:
        with pytest.raises(
            AttributeError,
            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule.",
        ):
            lightning_getattr(m, "this_attr_not_exist")


def test_lightning_setattr(tmpdir):
    """Test that the lightning_setattr works in all cases."""
    models, _ = model_and_trainer_cases()
    *__, model5, model6, model7, model8 = models
    for m in models[:3]:
        lightning_setattr(m, "learning_rate", 10)
        assert lightning_getattr(m, "learning_rate") == 10, "attribute not correctly set"

    lightning_setattr(model5, "batch_size", 128)
    lightning_setattr(model6, "batch_size", 128)
    lightning_setattr(model7, "batch_size", 128)
    # model8 must be set explicitly; otherwise the assertion below would only
    # hold if its datamodule shared an hparams dict with another model.
    lightning_setattr(model8, "batch_size", 128)
    assert lightning_getattr(model5, "batch_size") == 128, "batch_size not correctly set"
    assert lightning_getattr(model6, "batch_size") == 128, "batch_size not correctly set"
    assert lightning_getattr(model7, "batch_size") == 128, "batch_size not correctly set"
    assert lightning_getattr(model8, "batch_size") == 128, "batch_size not correctly set"

    for m in models:
        with pytest.raises(
            AttributeError,
            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule.",
        ):
            lightning_setattr(m, "this_attr_not_exist", None)


def test_str_to_bool_or_str():
    """Test that str_to_bool_or_str maps boolean-like strings to bools and returns other strings unchanged."""
    true_cases = ["y", "yes", "t", "true", "on", "1"]
    false_cases = ["n", "no", "f", "false", "off", "0"]
    other_cases = ["yyeess", "noooo", "lightning"]

    for case in true_cases:
        assert str_to_bool_or_str(case) is True

    for case in false_cases:
        assert str_to_bool_or_str(case) is False

    for case in other_cases:
        assert str_to_bool_or_str(case) == case


def test_str_to_bool():
    """Test that str_to_bool maps boolean-like strings to bools and raises ``ValueError`` for other strings."""
    true_cases = ["y", "yes", "t", "true", "on", "1"]
    false_cases = ["n", "no", "f", "false", "off", "0"]
    other_cases = ["yyeess", "noooo", "lightning"]

    for case in true_cases:
        assert str_to_bool(case) is True

    for case in false_cases:
        assert str_to_bool(case) is False

    for case in other_cases:
        with pytest.raises(ValueError):
            str_to_bool(case)


def test_str_to_bool_or_int():
    """Test that str_to_bool_or_int yields a bool, then an int, and otherwise the original string."""
    assert str_to_bool_or_int("0") is False
    assert str_to_bool_or_int("1") is True
    assert str_to_bool_or_int("true") is True
    assert str_to_bool_or_int("2") == 2
    assert str_to_bool_or_int("abc") == "abc"


def test_is_picklable(tmpdir):
    """Test that is_picklable tells picklable objects apart from unpicklable ones."""
    # See the full list of picklable types at
    # https://docs.python.org/3/library/pickle.html#pickle-picklable
    class UnpicklableClass:
        # Only classes defined at the top level of a module are picklable.
        pass

    true_cases = [None, True, 123, "str", (123, "str"), max]
    false_cases = [unpicklable_function, UnpicklableClass, ScriptModule()]

    for case in true_cases:
        assert is_picklable(case) is True

    for case in false_cases:
        assert is_picklable(case) is False


def test_clean_namespace(tmpdir):
    """Test that clean_namespace removes the unpicklable entries from a dict in place."""
    # See the full list of picklable types at
    # https://docs.python.org/3/library/pickle.html#pickle-picklable
    class UnpicklableClass:
        # Only classes defined at the top level of a module are picklable.
        pass

    test_case = {"1": None, "2": True, "3": 123, "4": unpicklable_function, "5": UnpicklableClass}

    clean_namespace(test_case)

    assert test_case == {"1": None, "2": True, "3": 123}


def test_parse_class_init_keys(tmpdir):
    """Test that parse_class_init_keys returns the names of ``self``, ``*args`` and ``**kwargs``."""

    class Class:
        def __init__(self, hparams, *my_args, anykw=42, **my_kwargs):
            pass

    assert parse_class_init_keys(Class) == ("self", "my_args", "my_kwargs")


def test_get_init_args(tmpdir):
    """Test that get_init_args extracts the arguments of an ``__init__`` frame and nothing from other frames."""

    class AutomaticArgsModel:
        def __init__(self, anyarg, anykw=42, **kwargs):
            super().__init__()
            self.get_init_args_wrapper()

        def get_init_args_wrapper(self):
            # Inspect the caller's frame rather than this helper's own frame.
            frame = inspect.currentframe().f_back
            self.result = get_init_args(frame)

    my_class = AutomaticArgsModel("test", anykw=32, otherkw=123)
    assert my_class.result == {"anyarg": "test", "anykw": 32, "otherkw": 123}

    # Called from the test function instead of `__init__`: nothing is collected.
    my_class.get_init_args_wrapper()
    assert my_class.result == {}


def test_collect_init_args():
    """Test that collect_init_args gathers the ``__init__`` arguments of the parent and the child class."""

    class AutomaticArgsParent:
        def __init__(self, anyarg, anykw=42, **kwargs):
            super().__init__()
            self.get_init_args_wrapper()

        def get_init_args_wrapper(self):
            frame = inspect.currentframe()
            self.result = collect_init_args(frame, [])

    class AutomaticArgsChild(AutomaticArgsParent):
        def __init__(self, anyarg, childarg, anykw=42, childkw=42, **kwargs):
            super().__init__(anyarg, anykw=anykw, **kwargs)

    my_class = AutomaticArgsChild("test1", "test2", anykw=32, childkw=22, otherkw=123)
    assert my_class.result[0] == {"anyarg": "test1", "anykw": 32, "otherkw": 123}
    assert my_class.result[1] == {"anyarg": "test1", "childarg": "test2", "anykw": 32, "childkw": 22, "otherkw": 123}


def test_attribute_dict(tmpdir):
    """Test that AttributeDict exposes its items as attributes for reading and writing."""
    # Test initialization
    inputs = {"key1": 1, "key2": "abc"}
    ad = AttributeDict(inputs)
    for key, value in inputs.items():
        assert getattr(ad, key) == value

    # Test adding new items
    ad = AttributeDict()
    ad.update({"key1": 1})
    assert ad.key1 == 1

    # Test updating existing items
    ad = AttributeDict({"key1": 1})
    ad.key1 = 123
    assert ad.key1 == 123


def test_flatten_dict(tmpdir):
    """Test that flatten_dict merges the leaves of nested dicts into a single flat dict."""
    d = {"1": 1, "_": {"2": 2, "_": {"3": 3, "4": 4}}}

    expected = {"1": 1, "2": 2, "3": 3, "4": 4}

    assert flatten_dict(d) == expected