# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect

import pytest
from torch.jit import ScriptModule

from pytorch_lightning.utilities.parsing import (
    AttributeDict,
    clean_namespace,
    collect_init_args,
    flatten_dict,
    get_init_args,
    is_picklable,
    lightning_getattr,
    lightning_hasattr,
    lightning_setattr,
    parse_class_init_keys,
    str_to_bool,
    str_to_bool_or_int,
    str_to_bool_or_str,
)

# A lambda is used deliberately: module-level functions defined with `def` can
# be pickled by reference, while lambdas cannot, which `test_is_picklable` and
# `test_clean_namespace` below rely on.
unpicklable_function = lambda: None


@pytest.fixture(scope="module")
def model_cases():
    """Build dummy models covering every attribute-lookup path: plain
    namespace, `hparams` namespace, `hparams` dict, and datamodule."""

    class TestHparamsNamespace:
        learning_rate = 1

        def __contains__(self, item):
            return item == "learning_rate"

    TestHparamsDict = {'learning_rate': 2}

    class TestModel1:  # test for namespace
        learning_rate = 0

    model1 = TestModel1()

    class TestModel2:  # test for hparams namespace
        hparams = TestHparamsNamespace()

    model2 = TestModel2()

    class TestModel3:  # test for hparams dict
        hparams = TestHparamsDict

    model3 = TestModel3()

    class TestModel4:  # fail case: no `learning_rate` anywhere the helpers look
        batch_size = 1

    model4 = TestModel4()

    class DataModule:
        batch_size = 8

    class Trainer:
        datamodule = DataModule

    class TestModel5:  # test for datamodule
        trainer = Trainer

    model5 = TestModel5()

    class TestModel6:  # test for datamodule w/ hparams w/o attribute (should use datamodule)
        trainer = Trainer
        hparams = TestHparamsDict

    model6 = TestModel6()

    TestHparamsDict2 = {'batch_size': 2}

    class TestModel7:  # test for datamodule w/ hparams w/ attribute (should use datamodule)
        trainer = Trainer
        hparams = TestHparamsDict2

    model7 = TestModel7()

    return model1, model2, model3, model4, model5, model6, model7
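

# The fixture spells out the lookup order the `lightning_*attr` helpers are
# expected to honor. A minimal sketch of that order as pinned down by the
# assertions below (inferred from these tests, not from the library docs):
# the model namespace is consulted first, then `model.hparams` (namespace or
# dict), then `model.trainer.datamodule`, with the datamodule winning when it
# and `hparams` both carry the attribute (model6/model7):
#
#     model = TestModel7()                    # hparams['batch_size'] == 2
#     lightning_getattr(model, 'batch_size')  # -> 8, taken from the datamodule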


def test_lightning_hasattr(model_cases):
    """Test that lightning_hasattr works in all cases."""
    model1, model2, model3, model4, model5, model6, model7 = models = model_cases
    assert lightning_hasattr(model1, 'learning_rate'), \
        'lightning_hasattr failed to find namespace variable'
    assert lightning_hasattr(model2, 'learning_rate'), \
        'lightning_hasattr failed to find hparams namespace variable'
    assert lightning_hasattr(model3, 'learning_rate'), \
        'lightning_hasattr failed to find hparams dict variable'
    assert not lightning_hasattr(model4, 'learning_rate'), \
        'lightning_hasattr found variable when it should not'
    assert lightning_hasattr(model5, 'batch_size'), \
        'lightning_hasattr failed to find batch_size in datamodule'
    assert lightning_hasattr(model6, 'batch_size'), \
        'lightning_hasattr failed to find batch_size in datamodule w/ hparams present'
    assert lightning_hasattr(model7, 'batch_size'), \
        'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'

    for m in models:
        assert not lightning_hasattr(m, "this_attr_not_exist")


def test_lightning_getattr(model_cases):
    """Test that lightning_getattr works in all cases."""
    models = model_cases
    # models[:3] store learning_rate = 0, 1 and 2 in the plain namespace, the
    # hparams namespace and the hparams dict, respectively.
    for i, m in enumerate(models[:3]):
        value = lightning_getattr(m, 'learning_rate')
        assert value == i, 'attribute not correctly extracted'

    model5, model6, model7 = models[4:]
    assert lightning_getattr(model5, 'batch_size') == 8, \
        'batch_size not correctly extracted'
    assert lightning_getattr(model6, 'batch_size') == 8, \
        'batch_size not correctly extracted'
    assert lightning_getattr(model7, 'batch_size') == 8, \
        'batch_size not correctly extracted'

    for m in models:
        with pytest.raises(
            AttributeError,
            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
        ):
            lightning_getattr(m, "this_attr_not_exist")


def test_lightning_setattr(model_cases):
    """Test that lightning_setattr works in all cases."""
    # Note: the fixture is module-scoped, so the mutations below are why this
    # test is defined after test_lightning_getattr.
    models = model_cases
    for m in models[:3]:
        lightning_setattr(m, 'learning_rate', 10)
        assert lightning_getattr(m, 'learning_rate') == 10, \
            'attribute not correctly set'

    model5, model6, model7 = models[4:]
    lightning_setattr(model5, 'batch_size', 128)
    lightning_setattr(model6, 'batch_size', 128)
    lightning_setattr(model7, 'batch_size', 128)
    assert lightning_getattr(model5, 'batch_size') == 128, \
        'batch_size not correctly set'
    assert lightning_getattr(model6, 'batch_size') == 128, \
        'batch_size not correctly set'
    assert lightning_getattr(model7, 'batch_size') == 128, \
        'batch_size not correctly set'

    for m in models:
        with pytest.raises(
            AttributeError,
            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
        ):
            lightning_setattr(m, "this_attr_not_exist", None)
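

# What these assertions pin down is read-back consistency: after
# lightning_setattr, a lightning_getattr on the same name observes the new
# value, even when the attribute lives behind `hparams` or
# `trainer.datamodule`. For example:
#
#     lightning_setattr(model5, 'batch_size', 128)
#     lightning_getattr(model5, 'batch_size')  # -> 128, read back via the datamodule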


def test_str_to_bool_or_str():
    true_cases = ['y', 'yes', 't', 'true', 'on', '1']
    false_cases = ['n', 'no', 'f', 'false', 'off', '0']
    other_cases = ['yyeess', 'noooo', 'lightning']

    for case in true_cases:
        assert str_to_bool_or_str(case) is True

    for case in false_cases:
        assert str_to_bool_or_str(case) is False

    for case in other_cases:
        assert str_to_bool_or_str(case) == case


def test_str_to_bool():
    true_cases = ['y', 'yes', 't', 'true', 'on', '1']
    false_cases = ['n', 'no', 'f', 'false', 'off', '0']
    other_cases = ['yyeess', 'noooo', 'lightning']

    for case in true_cases:
        assert str_to_bool(case) is True

    for case in false_cases:
        assert str_to_bool(case) is False

    for case in other_cases:
        with pytest.raises(ValueError):
            str_to_bool(case)


def test_str_to_bool_or_int():
    assert str_to_bool_or_int("0") is False
    assert str_to_bool_or_int("1") is True
    assert str_to_bool_or_int("true") is True
    assert str_to_bool_or_int("2") == 2
    assert str_to_bool_or_int("abc") == "abc"
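

# Taken together, the three tests above describe a parsing cascade: recognized
# boolean words map to True/False, `str_to_bool_or_int` additionally falls
# through to int, and anything else is returned unchanged (or rejected, for
# the strict `str_to_bool`). A plausible use is as an argparse `type=`
# converter -- a hypothetical parser, not part of this suite:
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--limit_batches", type=str_to_bool_or_int)
#     # "--limit_batches true" -> True, "--limit_batches 2" -> 2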


def test_is_picklable():
    # See the full list of picklable types at
    # https://docs.python.org/3/library/pickle.html#pickle-picklable
    class UnpicklableClass:
        # Only classes defined at the top level of a module are picklable.
        pass

    true_cases = [None, True, 123, "str", (123, "str"), max]
    false_cases = [unpicklable_function, UnpicklableClass, ScriptModule()]

    for case in true_cases:
        assert is_picklable(case) is True

    for case in false_cases:
        assert is_picklable(case) is False
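

# A hand-rolled equivalent of the predicate exercised above -- a sketch only,
# assuming nothing about the library's actual implementation beyond what the
# test cases imply (the real helper presumably catches narrower exceptions):
#
#     import pickle
#
#     def is_picklable_sketch(obj) -> bool:
#         try:
#             pickle.dumps(obj)
#             return True
#         except Exception:
#             return False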


def test_clean_namespace():
    # See the full list of picklable types at
    # https://docs.python.org/3/library/pickle.html#pickle-picklable
    class UnpicklableClass:
        # Only classes defined at the top level of a module are picklable.
        pass

    test_case = {
        "1": None,
        "2": True,
        "3": 123,
        "4": unpicklable_function,
        "5": UnpicklableClass,
    }

    clean_namespace(test_case)

    # The unpicklable entries were removed from the very same dict object,
    # i.e. the namespace is cleaned in place.
    assert test_case == {"1": None, "2": True, "3": 123}


def test_parse_class_init_keys():

    class Class:

        def __init__(self, hparams, *my_args, anykw=42, **my_kwargs):
            pass

    assert parse_class_init_keys(Class) == ("self", "my_args", "my_kwargs")
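

# The assertion shows that parse_class_init_keys reports the names of the
# `self` parameter, the *varargs and the **kwargs of `__init__`. The same
# triple can be recovered with the stdlib (sketch):
#
#     spec = inspect.getfullargspec(Class.__init__)
#     (spec.args[0], spec.varargs, spec.varkw)  # ('self', 'my_args', 'my_kwargs')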


def test_get_init_args():

    class AutomaticArgsModel:

        def __init__(self, anyarg, anykw=42, **kwargs):
            super().__init__()

            self.get_init_args_wrapper()

        def get_init_args_wrapper(self):
            # `f_back` is the caller's frame: `__init__` when invoked from the
            # constructor, the test frame otherwise.
            frame = inspect.currentframe().f_back
            self.result = get_init_args(frame)

    my_class = AutomaticArgsModel("test", anykw=32, otherkw=123)
    assert my_class.result == {"anyarg": "test", "anykw": 32, "otherkw": 123}

    # Called from outside the constructor, the inspected frame carries no init
    # args, so the helper returns an empty dict.
    my_class.get_init_args_wrapper()
    assert my_class.result == {}
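

# What this test pins down: get_init_args(frame) recovers the bound arguments
# of the `__init__` call owning the given frame -- positional, keyword and
# **kwargs alike, excluding `self` -- and returns {} for a frame that is not
# inside a constructor. The frame handed over is plain stdlib introspection:
#
#     frame = inspect.currentframe().f_back  # one level up the call stack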


def test_collect_init_args():

    class AutomaticArgsParent:

        def __init__(self, anyarg, anykw=42, **kwargs):
            super().__init__()
            self.get_init_args_wrapper()

        def get_init_args_wrapper(self):
            frame = inspect.currentframe()
            self.result = collect_init_args(frame, [])

    class AutomaticArgsChild(AutomaticArgsParent):

        def __init__(self, anyarg, childarg, anykw=42, childkw=42, **kwargs):
            super().__init__(anyarg, anykw=anykw, **kwargs)

    my_class = AutomaticArgsChild("test1", "test2", anykw=32, childkw=22, otherkw=123)
    # One dict per constructor on the call stack: the parent's init args land
    # at index 0, the child's at index 1.
    assert my_class.result[0] == {"anyarg": "test1", "anykw": 32, "otherkw": 123}
    assert my_class.result[1] == {"anyarg": "test1", "childarg": "test2", "anykw": 32, "childkw": 22, "otherkw": 123}
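

# collect_init_args evidently walks the frame chain and gathers one init-arg
# dict per constructor it passes through -- conceptually (a sketch inferred
# from the assertions above, not the actual implementation):
#
#     while frame is not None:
#         # record this frame's init args when it belongs to an __init__
#         frame = frame.f_back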


def test_attribute_dict():
    # Test initialization
    inputs = {
        'key1': 1,
        'key2': 'abc',
    }
    ad = AttributeDict(inputs)
    for key, value in inputs.items():
        assert getattr(ad, key) == value

    # Test adding new items
    ad = AttributeDict()
    ad.update({'key1': 1})
    assert ad.key1 == 1

    # Test updating existing items
    ad = AttributeDict({'key1': 1})
    ad.key1 = 123
    assert ad.key1 == 123
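

# In short, AttributeDict behaves as a dict whose keys double as attributes,
# readable and writable through either interface:
#
#     ad = AttributeDict({'key1': 1})
#     assert ad.key1 == ad['key1']  # same entry, two spellings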


def test_flatten_dict():
    d = {'1': 1, '_': {'2': 2, '_': {'3': 3, '4': 4}}}

    expected = {
        '1': 1,
        '2': 2,
        '3': 3,
        '4': 4,
    }

    # Nested values are hoisted to the top level; the '_' parent keys do not
    # survive flattening.
    assert flatten_dict(d) == expected