lightning/tests/tests_pytorch/models/test_hparams.py

908 lines
30 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import os
import pickle
from argparse import Namespace
from dataclasses import dataclass
from enum import Enum
from unittest import mock
import cloudpickle
import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from lightning_utilities.core.imports import RequirementCache
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.utils import no_warning_call
if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf
from omegaconf.dictconfig import DictConfig
class SaveHparamsModel(BoringModel):
    """Model that forwards a single ``hparams`` object to ``save_hyperparameters``.

    The standard hparams tests construct it with a ``Namespace``, a ``dict`` or an OmegaConf container.
    """

    def __init__(self, hparams):
        super().__init__()
        # the entries of the passed container become accessible as ``self.hparams.<name>``
        self.save_hyperparameters(hparams)
def decorate(func):
    """Wrap ``func`` in a transparent pass-through wrapper.

    ``functools.wraps`` copies the metadata of ``func`` (name, docstring, ``__wrapped__``) onto the
    wrapper, so a decorated ``__init__`` still looks like the original to introspection.
    """

    @functools.wraps(func)
    def _passthrough(*positional, **keyword):
        outcome = func(*positional, **keyword)
        return outcome

    return _passthrough
class SaveHparamsDecoratedModel(BoringModel):
    """Model whose ``__init__`` is wrapped twice by ``decorate`` and accepts extra ``*my_args``/``**my_kwargs``.

    It exercises ``save_hyperparameters`` when the constructor sits behind decorator wrapper frames.
    """

    @decorate
    @decorate
    def __init__(self, hparams, *my_args, **my_kwargs):
        super().__init__()
        # the entries of the passed container become accessible as ``self.hparams.<name>``
        self.save_hyperparameters(hparams)
class SaveHparamsDataModule(BoringDataModule):
    """Datamodule that forwards a single ``hparams`` object to ``save_hyperparameters``."""

    def __init__(self, hparams):
        super().__init__()
        # the entries of the passed container become accessible as ``self.hparams.<name>``
        self.save_hyperparameters(hparams)
class SaveHparamsDecoratedDataModule(BoringDataModule):
    """Datamodule whose ``__init__`` is wrapped twice by ``decorate`` and accepts extra ``*my_args``/``**my_kwargs``."""

    @decorate
    @decorate
    def __init__(self, hparams, *my_args, **my_kwargs):
        super().__init__()
        # the entries of the passed container become accessible as ``self.hparams.<name>``
        self.save_hyperparameters(hparams)
# -------------------------
# STANDARD TESTS
# -------------------------
def _run_standard_hparams_test(tmpdir, model, cls, datamodule=None, try_overwrite=False):
    """Fit ``model`` once and check that ``test_arg=14`` survives a checkpoint round trip.

    The hparams are read from ``datamodule`` when ``cls`` is a ``LightningDataModule`` subclass, and from
    ``model`` otherwise.

    Args:
        tmpdir: root directory for the trainer.
        model: the ``LightningModule`` to fit.
        cls: class used to reload from the checkpoint.
        datamodule: datamodule handed to ``fit`` when ``cls`` is a datamodule class.
        try_overwrite: also check that ``load_from_checkpoint(..., test_arg=78)`` overrides the saved value.

    Returns:
        The path of the checkpoint written during training.
    """
    is_datamodule = issubclass(cls, LightningDataModule)
    holder = datamodule if is_datamodule else model
    expected_type = type(holder.hparams)
    assert holder.hparams.test_arg == 14

    # training writes the checkpoint inspected below
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
    trainer.fit(model, datamodule=(datamodule if is_datamodule else None))

    # the hparams must be stored in the raw checkpoint dictionary
    ckpt_path = _raw_checkpoint_path(trainer)
    ckpt = torch.load(ckpt_path)
    hparams_key = cls.CHECKPOINT_HYPER_PARAMS_KEY
    assert hparams_key in ckpt
    assert ckpt[hparams_key]["test_arg"] == 14

    # reloading restores both the value and the container type
    reloaded = cls.load_from_checkpoint(ckpt_path)
    assert reloaded.hparams.test_arg == 14
    assert isinstance(reloaded.hparams, expected_type)

    if try_overwrite:
        overridden = cls.load_from_checkpoint(ckpt_path, test_arg=78)
        assert overridden.hparams.test_arg == 78

    return ckpt_path
@pytest.mark.parametrize(
    "cls", [SaveHparamsModel, SaveHparamsDecoratedModel, SaveHparamsDataModule, SaveHparamsDecoratedDataModule]
)
def test_namespace_hparams(tmpdir, cls):
    """An ``argparse.Namespace`` passed as hparams round-trips through a checkpoint."""
    namespace = Namespace(test_arg=14)
    if issubclass(cls, LightningDataModule):
        model, datamodule = BoringModel(), cls(hparams=namespace)
    else:
        model, datamodule = cls(hparams=namespace), None

    _run_standard_hparams_test(tmpdir, model, cls, datamodule=datamodule)
@pytest.mark.parametrize(
    "cls", [SaveHparamsModel, SaveHparamsDecoratedModel, SaveHparamsDataModule, SaveHparamsDecoratedDataModule]
)
def test_dict_hparams(tmpdir, cls):
    """A plain ``dict`` passed as hparams round-trips through a checkpoint."""
    params = {"test_arg": 14}
    datamodule = None
    if issubclass(cls, LightningDataModule):
        model = BoringModel()
        datamodule = cls(hparams=params)
    else:
        model = cls(hparams=params)

    _run_standard_hparams_test(tmpdir, model, cls, datamodule=datamodule)
@RunIf(omegaconf=True)
@pytest.mark.parametrize(
    "cls", [SaveHparamsModel, SaveHparamsDecoratedModel, SaveHparamsDataModule, SaveHparamsDecoratedDataModule]
)
def test_omega_conf_hparams(tmpdir, cls):
    """An OmegaConf container is kept as such when saved and is restored as an OmegaConf container."""
    conf = OmegaConf.create({"test_arg": 14, "mylist": [15.4, {"a": 1, "b": 2}]})
    if issubclass(cls, LightningDataModule):
        model = BoringModel()
        datamodule = cls(hparams=conf)
        holder = datamodule
    else:
        model = cls(hparams=conf)
        datamodule = None
        holder = model
    assert isinstance(holder.hparams, Container)

    ckpt_path = _run_standard_hparams_test(tmpdir, model, cls, datamodule=datamodule)

    reloaded = cls.load_from_checkpoint(ckpt_path)
    assert isinstance(reloaded.hparams, Container)
    # values that only exist in this config
    assert reloaded.hparams.test_arg == 14
    assert reloaded.hparams.mylist[0] == 15.4
def test_explicit_args_hparams(tmpdir):
    """Arguments named explicitly in ``save_hyperparameters`` are saved and can be overridden on load."""

    class LocalModel(BoringModel):
        def __init__(self, test_arg, test_arg2):
            super().__init__()
            self.save_hyperparameters("test_arg", "test_arg2")

    ckpt_path = _run_standard_hparams_test(tmpdir, LocalModel(test_arg=14, test_arg2=90), LocalModel)

    reloaded = LocalModel.load_from_checkpoint(ckpt_path, test_arg2=120)
    assert reloaded.hparams.test_arg2 == 120
def test_implicit_args_hparams(tmpdir):
    """``save_hyperparameters()`` without arguments saves every init argument, overridable on load."""

    class LocalModel(BoringModel):
        def __init__(self, test_arg, test_arg2):
            super().__init__()
            self.save_hyperparameters()

    ckpt_path = _run_standard_hparams_test(tmpdir, LocalModel(test_arg=14, test_arg2=90), LocalModel)

    reloaded = LocalModel.load_from_checkpoint(ckpt_path, test_arg2=120)
    assert reloaded.hparams.test_arg2 == 120
def test_explicit_missing_args_hparams(tmpdir):
    """Only the arguments named in ``save_hyperparameters`` are saved; the other init arguments are not.

    ``test_arg2`` is not saved, so it must be supplied again when loading from the checkpoint.
    """

    class LocalModel(BoringModel):
        def __init__(self, test_arg, test_arg2):
            super().__init__()
            self.save_hyperparameters("test_arg")

    model = LocalModel(test_arg=14, test_arg2=90)

    # test proper property assignments
    assert model.hparams.test_arg == 14
    assert "test_arg2" not in model.hparams

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)

    # make sure the raw checkpoint saved only the requested properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    saved_hparams = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    assert saved_hparams["test_arg"] == 14
    assert "test_arg2" not in saved_hparams

    # verify that model loads correctly
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
    assert model.hparams.test_arg == 14
    assert "test_arg2" not in model.hparams  # test_arg2 is not registered in class init
    # NOTE: no return value; pytest warns (PytestReturnNotNoneWarning) when a test returns non-None
# -------------------------
# SPECIFIC TESTS
# -------------------------
def test_class_nesting():
    """Reading ``hparams`` works on a module created inside nested functions and class methods.

    ``MyModule`` never calls ``save_hyperparameters``. The test has no assertion on the value; it only
    checks that accessing ``hparams`` does not fail from these scopes.
    """

    class MyModule(LightningModule):
        def forward(self):
            ...

    # make sure PL modules are always nn.Module
    a = MyModule()
    assert isinstance(a, torch.nn.Module)

    # instance created inside a nested function
    def test_outside():
        a = MyModule()
        _ = a.hparams

    # instances created inside methods of a local class (directly and via the nested function)
    class A:
        def test(self):
            a = MyModule()
            _ = a.hparams

        def test2(self):
            test_outside()

    test_outside()
    A().test2()
    A().test()
class CustomBoringModel(BoringModel):
    """Base model for the init-argument collection tests; saves its ``batch_size`` argument automatically."""

    def __init__(self, batch_size=64):
        super().__init__()
        # called without arguments: the ``__init__`` arguments are collected from the call frame
        self.save_hyperparameters()
class SubClassBoringModel(CustomBoringModel):
    """Subclass adding a keyword-only ``subclass_arg`` and calling ``save_hyperparameters()`` itself as well."""

    # class attribute, not an ``__init__`` argument; presumably here to check that class attributes are
    # not picked up as hparams — TODO confirm intent
    any_other_loss = torch.nn.CrossEntropyLoss()

    def __init__(self, *args, subclass_arg=1200, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
class NonSavingSubClassBoringModel(CustomBoringModel):
    """Subclass with an extra ``subclass_arg`` that does NOT call ``save_hyperparameters`` itself.

    ``test_collect_init_arguments`` still expects ``hparams.subclass_arg == 1200``, i.e. the call in
    ``CustomBoringModel.__init__`` also collects the subclass's constructor arguments.
    """

    # class attribute, not an ``__init__`` argument
    any_other_loss = torch.nn.CrossEntropyLoss()

    def __init__(self, *args, subclass_arg=1200, **kwargs):
        super().__init__(*args, **kwargs)
class SubSubClassBoringModel(SubClassBoringModel):
    """Inherits ``SubClassBoringModel`` unchanged; it defines no ``__init__`` of its own."""

    pass
class AggSubClassBoringModel(SubClassBoringModel):
    """Subclass taking an ``nn.Module`` loss (``my_loss``) as a keyword argument that is saved as an hparam."""

    # NOTE: the default loss instance is created once, when the class is defined, and is shared by every
    # instance that relies on the default; ``test_collect_init_arguments`` always passes its own ``my_loss``.
    def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
class UnconventionalArgsBoringModel(CustomBoringModel):
    """A model that has unconventional names for "self", "*args" and "**kwargs".

    Argument collection must not rely on the conventional names ``self``/``args``/``kwargs``.
    """

    def __init__(obj, *more_args, other_arg=300, **more_kwargs):
        # intentionally named obj
        super().__init__(*more_args, **more_kwargs)
        obj.save_hyperparameters()
if _OMEGACONF_AVAILABLE:

    class DictConfSubClassBoringModel(SubClassBoringModel):
        """Subclass whose extra keyword argument ``dict_conf`` defaults to an OmegaConf container."""

        def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
            super().__init__(*args, **kwargs)
            self.save_hyperparameters()

else:

    class DictConfSubClassBoringModel:
        """Placeholder so the name exists when omegaconf is not installed.

        The matching test parameter is marked ``RunIf(omegaconf=True)`` and therefore skipped.
        """

        ...
@pytest.mark.parametrize(
    "cls",
    [
        CustomBoringModel,
        SubClassBoringModel,
        NonSavingSubClassBoringModel,
        SubSubClassBoringModel,
        AggSubClassBoringModel,
        UnconventionalArgsBoringModel,
        pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
    ],
)
def test_collect_init_arguments(tmpdir, cls):
    """Constructor arguments are collected automatically across the class hierarchy and checkpointed."""
    if cls is AggSubClassBoringModel:
        extra_args = {"my_loss": torch.nn.CosineEmbeddingLoss()}
    elif cls is DictConfSubClassBoringModel:
        extra_args = {"dict_conf": OmegaConf.create({"my_param": "anything"})}
    else:
        extra_args = {}

    # the default batch size is collected ...
    assert cls(**extra_args).hparams.batch_size == 64

    # ... and so is an explicitly passed one
    model = cls(batch_size=179, **extra_args)
    assert model.hparams.batch_size == 179
    if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel)):
        assert model.hparams.subclass_arg == 1200
    if isinstance(model, AggSubClassBoringModel):
        assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)

    # the collected values must be written to the checkpoint
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)
    ckpt_path = _raw_checkpoint_path(trainer)
    ckpt = torch.load(ckpt_path)
    hparams_key = LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
    assert hparams_key in ckpt
    assert ckpt[hparams_key]["batch_size"] == 179

    # ... and restored on reload
    reloaded = cls.load_from_checkpoint(ckpt_path)
    assert reloaded.hparams.batch_size == 179
    if isinstance(reloaded, AggSubClassBoringModel):
        assert isinstance(reloaded.hparams.my_loss, torch.nn.CosineEmbeddingLoss)
    if isinstance(reloaded, DictConfSubClassBoringModel):
        assert isinstance(reloaded.hparams.dict_conf, Container)
        assert reloaded.hparams.dict_conf["my_param"] == "anything"

    # keyword arguments given at load time take precedence over the saved values
    assert cls.load_from_checkpoint(ckpt_path, batch_size=99).hparams.batch_size == 99
def _raw_checkpoint_path(trainer) -> str:
    """Return the path of the first file containing ``.ckpt`` in the checkpoint callback's directory.

    Raises:
        AssertionError: if the directory holds no such file.
    """
    ckpt_dir = trainer.checkpoint_callback.dirpath
    candidates = [name for name in os.listdir(ckpt_dir) if ".ckpt" in name]
    assert candidates
    return os.path.join(ckpt_dir, candidates[0])
@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule))
def test_save_hyperparameters_under_composition(base_class):
    """Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
    collected."""

    class ChildInComposition(base_class):
        def __init__(self, same_arg):
            super().__init__()
            self.save_hyperparameters()

    class NotPLSubclass:  # intentionally not subclassing LightningModule/LightningDataModule
        def __init__(self, same_arg="parent_default", other_arg="other"):
            # the parent's own ``same_arg``/``other_arg`` must not leak into the child's hparams
            self.child = ChildInComposition(same_arg="cocofruit")

    parent = NotPLSubclass()
    # only the child's own argument is collected, with the value passed to the child
    assert parent.child.hparams == dict(same_arg="cocofruit")
class LocalVariableModelSuperLast(BoringModel):
    """This model has the super().__init__() call at the end.

    It never calls ``save_hyperparameters`` itself.
    """

    # NOTE(review): its entry in ``test_collect_init_arguments_with_local_vars`` is commented out; the
    # reason is not visible in this file.
    def __init__(self, arg1, arg2, *args, **kwargs):
        self.argument1 = arg1  # arg2 intentionally not set
        arg1 = "overwritten"
        local_var = 1234  # noqa: F841
        super().__init__(*args, **kwargs)  # this is intentionally here at the end
class LocalVariableModelSuperFirst(BoringModel):
    """This model has the save_hyperparameters() call at the end.

    ``arg1`` is reassigned and a plain local is defined before saving, so the test can check that the
    current argument value is captured and that non-argument locals are ignored.
    """

    def __init__(self, arg1, arg2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.argument1 = arg1  # arg2 intentionally not set
        arg1 = "overwritten"  # rebinding the argument: the saved value is expected to be "overwritten"
        local_var = 1234  # noqa: F841  # not an init argument, must not be saved
        self.save_hyperparameters()  # this is intentionally here at the end
@pytest.mark.parametrize(
    "cls",
    [
        LocalVariableModelSuperFirst,
        # LocalVariableModelSuperLast,
    ],
)
def test_collect_init_arguments_with_local_vars(cls):
    """Only constructor arguments are collected, with the values they hold when ``save_hyperparameters`` runs."""
    hparams = cls(arg1=1, arg2=2).hparams

    assert "local_var" not in hparams
    # arg1 was reassigned inside __init__ before saving
    assert hparams["arg1"] == "overwritten"
    assert hparams["arg2"] == 2
class AnotherArgModel(BoringModel):
    """Passes the *value* of ``arg1`` (not its name) to ``save_hyperparameters``.

    ``test_single_config_models_fail`` constructs it with ``arg1=42`` and expects a ``ValueError``.
    """

    def __init__(self, arg1):
        super().__init__()
        self.save_hyperparameters(arg1)
class OtherArgsModel(BoringModel):
    """Passes the *values* of ``arg1`` and ``arg2`` to ``save_hyperparameters``.

    ``test_single_config_models_fail`` constructs it with ``arg1=3.14, arg2="abc"`` and expects a ``ValueError``.
    """

    def __init__(self, arg1, arg2):
        super().__init__()
        self.save_hyperparameters(arg1, arg2)
@pytest.mark.parametrize(
    "cls,config", [(AnotherArgModel, dict(arg1=42)), (OtherArgsModel, dict(arg1=3.14, arg2="abc"))]
)
def test_single_config_models_fail(tmpdir, cls, config):
    """Passing unsupported values (such as a bare number) to ``save_hyperparameters`` raises ``ValueError``."""
    with pytest.raises(ValueError):
        cls(**config)
@pytest.mark.parametrize("past_key", ["module_arguments"])
def test_load_past_checkpoint(tmpdir, past_key):
    """A checkpoint storing its hparams under a legacy key can still be loaded."""
    model = CustomBoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    # rewrite the checkpoint so that it looks like one produced by an older version
    ckpt_path = _raw_checkpoint_path(trainer)
    ckpt = torch.load(ckpt_path)
    legacy_hparams = ckpt.pop(LightningModule.CHECKPOINT_HYPER_PARAMS_KEY)
    legacy_hparams["batch_size"] = -17
    ckpt[past_key] = legacy_hparams
    ckpt["hparams_type"] = "Namespace"
    torch.save(ckpt, ckpt_path)

    # the value edited under the legacy key is the one that gets loaded
    assert CustomBoringModel.load_from_checkpoint(ckpt_path).hparams.batch_size == -17
def test_hparams_pickle(tmpdir):
    """``AttributeDict`` survives a round trip through both ``pickle`` and ``cloudpickle``."""
    attr_dict = AttributeDict({"key1": 1, "key2": "abc"})
    for dumps in (pickle.dumps, cloudpickle.dumps):
        # both payloads are read back with the standard ``pickle.loads``
        assert attr_dict == pickle.loads(dumps(attr_dict))
class UnpickleableArgsBoringModel(BoringModel):
    """A model that has an attribute that cannot be pickled.

    ``pickle_me`` defaults to a lambda; the constructor asserts that ``is_picklable`` rejects it.
    """

    def __init__(self, foo="bar", pickle_me=(lambda x: x + 1), **kwargs):
        super().__init__(**kwargs)
        assert not is_picklable(pickle_me)
        self.save_hyperparameters()
def test_hparams_pickle_warning(tmpdir):
    """Hyperparameters that cannot be pickled are dropped, with a warning, during ``fit``."""
    model = UnpickleableArgsBoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1)

    expected_warning = "attribute 'pickle_me' removed from hparams because it cannot be pickled"
    with pytest.warns(UserWarning, match=expected_warning):
        trainer.fit(model)

    assert "pickle_me" not in model.hparams
def test_hparams_save_yaml(tmpdir):
    """Hparams in several container types are written to YAML and read back with the same content.

    Enum members are expected to be stored by their ``name``. The OmegaConf round trip only runs when
    omegaconf is installed.
    """

    class Options(str, Enum):
        option1name = "option1val"
        option2name = "option2val"
        option3name = "option3val"

    hparams = dict(
        batch_size=32,
        learning_rate=0.001,
        data_root="./any/path/here",
        nested=dict(any_num=123, anystr="abcd"),
        switch=Options.option3name,
    )
    path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
    # ``DictConfig`` is only imported when omegaconf is available; referencing it unconditionally would
    # raise ``NameError`` in environments without omegaconf.
    mapping_types = (dict, DictConfig) if _OMEGACONF_AVAILABLE else (dict,)

    def _compare_params(loaded_params, default_params: dict):
        assert isinstance(loaded_params, mapping_types)
        assert loaded_params.keys() == default_params.keys()
        for k, v in default_params.items():
            if isinstance(v, Enum):
                assert v.name == loaded_params[k]
            else:
                assert v == loaded_params[k]

    save_hparams_to_yaml(path_yaml, hparams)
    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

    save_hparams_to_yaml(path_yaml, Namespace(**hparams))
    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

    save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

    if _OMEGACONF_AVAILABLE:
        save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
        _compare_params(load_hparams_from_yaml(path_yaml), hparams)
class NoArgsSubClassBoringModel(CustomBoringModel):
    """Subclass whose constructor takes no arguments.

    It only delegates to ``CustomBoringModel.__init__``, which calls ``save_hyperparameters()``.
    """

    def __init__(self):
        super().__init__()
@pytest.mark.parametrize("cls", [BoringModel, NoArgsSubClassBoringModel])
def test_model_nohparams_train_test(tmpdir, cls):
    """Models whose constructor takes no arguments can be fit and tested."""

    def _make_loader():
        return DataLoader(RandomDataset(32, 64), batch_size=32)

    model = cls()
    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(model, _make_loader())
    # the test loader is only built after fitting, as before
    trainer.test(dataloaders=_make_loader())
def test_model_ignores_non_exist_kwargument(tmpdir):
    """Unknown keyword arguments given to ``load_from_checkpoint`` do not end up in ``hparams``."""

    class LocalModel(BoringModel):
        def __init__(self, batch_size=15):
            super().__init__()
            self.save_hyperparameters()

    model = LocalModel()
    assert model.hparams.batch_size == 15

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    reloaded = LocalModel.load_from_checkpoint(_raw_checkpoint_path(trainer), non_exist_kwarg=99)
    assert "non_exist_kwarg" not in reloaded.hparams
class SuperClassPositionalArgs(BoringModel):
    """Takes a positional ``hparams`` dict and stores it directly in ``_hparams``."""

    def __init__(self, hparams):
        super().__init__()
        self._hparams = hparams  # pretend BoringModel did not call self.save_hyperparameters()
class SubClassVarArgs(SuperClassPositionalArgs):
    """Loading this model should accept hparams and init in the super class.

    The constructor only forwards ``*args``/``**kwargs``. ``test_args`` currently expects
    ``load_from_checkpoint`` to fail with a ``TypeError`` for this class.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
def test_args(tmpdir):
    """Inheritance case: the super class takes a positional arg and the subclass only forwards varargs.

    Loading from the checkpoint fails because the saved ``test`` entry reaches ``__init__`` as an
    unexpected keyword argument.
    """
    model = SubClassVarArgs(dict(test=1))
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)
    ckpt_path = _raw_checkpoint_path(trainer)

    expected_error = r"__init__\(\) got an unexpected keyword argument 'test'"
    with pytest.raises(TypeError, match=expected_error):
        SubClassVarArgs.load_from_checkpoint(ckpt_path)
class RuntimeParamChangeModelSaving(BoringModel):
    """Model accepting arbitrary keyword arguments, all collected by ``save_hyperparameters()``."""

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
@pytest.mark.parametrize("cls", [RuntimeParamChangeModelSaving])
def test_init_arg_with_runtime_change(tmpdir, cls):
    """Runtime edits to ``model.hparams`` do not leak into the exported hparams file."""
    model = cls(running_arg=123)
    assert model.hparams.running_arg == 123
    # a runtime change is visible on the model ...
    model.hparams.running_arg = -1
    assert model.hparams.running_arg == -1

    batch_limits = dict(limit_train_batches=2, limit_val_batches=2, limit_test_batches=2)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, **batch_limits)
    trainer.fit(model)

    # ... but the logged YAML keeps the value the model was constructed with
    logger = trainer.logger
    exported = load_hparams_from_yaml(os.path.join(logger.log_dir, logger.NAME_HPARAMS_FILE))
    assert exported.get("running_arg") == 123
class UnsafeParamModel(BoringModel):
    """Model whose ``my_path`` argument receives an fsspec ``LocalFileSystem`` in the fsspec test."""

    def __init__(self, my_path, any_param=123):
        super().__init__()
        self.save_hyperparameters()
def test_model_with_fsspec_as_parameter(tmpdir):
    """A model holding an fsspec filesystem as a hyperparameter can be fit and tested."""
    filesystem = LocalFileSystem(tmpdir)
    model = UnsafeParamModel(filesystem)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
    )
    trainer.fit(model)
    trainer.test()
@pytest.mark.skipif(not RequirementCache("hydra-core>=1.1"), reason="Requires Hydra's Compose API")
def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir):
    """Hydra-composed configs (bare, inside a dict, inside a list) can be saved as hparams and reloaded.

    This test relies on configuration saved under tests/models/conf/config.yaml.

    It is skipped unless hydra-core>=1.1 is installed. The previous condition,
    ``RequirementCache("hydra-core<1.1")``, is falsy when hydra is missing entirely, so the test ran
    and then crashed on the ``hydra`` import.
    """
    from hydra import compose, initialize

    class TestHydraModel(BoringModel):
        def __init__(self, args_0, args_1, args_2, kwarg_1=None):
            self.save_hyperparameters()
            assert self.hparams.args_0.log == "Something"
            assert self.hparams.args_1["cfg"].log == "Something"
            assert self.hparams.args_2[0].log == "Something"
            assert self.hparams.kwarg_1["cfg"][0].log == "Something"
            super().__init__()

    with initialize(config_path="conf"):
        args_0 = compose(config_name="config")
        args_1 = {"cfg": compose(config_name="config")}
        args_2 = [compose(config_name="config")]
        kwarg_1 = {"cfg": [compose(config_name="config")]}
        model = TestHydraModel(args_0, args_1, args_2, kwarg_1=kwarg_1)
        epochs = 2
        checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1)
        trainer = Trainer(
            default_root_dir=tmpdir,
            callbacks=[checkpoint_callback],
            limit_train_batches=10,
            limit_val_batches=10,
            max_epochs=epochs,
            logger=False,
        )
        trainer.fit(model)
        # reloading re-runs the assertions in ``TestHydraModel.__init__`` on the restored hparams
        _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)
@pytest.mark.parametrize("ignore", ("arg2", ("arg2", "arg3")))
def test_ignore_args_list_hparams(tmpdir, ignore):
    """Tests that args can be ignored in save_hyperparameters, given either a single name or a sequence of names."""

    class LocalModel(BoringModel):
        def __init__(self, arg1, arg2, arg3):
            super().__init__()
            self.save_hyperparameters(ignore=ignore)

    # ``ignore`` may be a single string. Iterating it directly would check its characters
    # ("a", "r", "g", "2") instead of the argument name, which makes the assertions vacuous.
    ignored_names = [ignore] if isinstance(ignore, str) else list(ignore)

    model = LocalModel(arg1=14, arg2=90, arg3=50)

    # test proper property assignments
    assert model.hparams.arg1 == 14
    for arg in ignored_names:
        assert arg not in model.hparams

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)

    # make sure the raw checkpoint saved the kept properties and not the ignored ones
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    saved_hparams = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    assert saved_hparams["arg1"] == 14
    for arg in ignored_names:
        assert arg not in saved_hparams

    # verify that model loads correctly (ignored args must be passed again)
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, arg2=123, arg3=100)
    assert model.hparams.arg1 == 14
    for arg in ignored_names:
        assert arg not in model.hparams
class IgnoreAllParametersModel(BoringModel):
    """Model that lists every constructor argument in ``ignore``, so nothing is left to save."""

    def __init__(self, arg1, arg2, arg3):
        super().__init__()
        self.save_hyperparameters(ignore=("arg1", "arg2", "arg3"))
class NoParametersModel(BoringModel):
    """Model that calls ``save_hyperparameters()`` although its constructor takes no arguments."""

    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
@pytest.mark.parametrize(
    "model",
    (
        IgnoreAllParametersModel(arg1=14, arg2=90, arg3=50),
        NoParametersModel(),
    ),
)
def test_save_no_parameters(model):
    """``save_hyperparameters`` works when there is nothing to save: both hparams views stay empty."""
    for saved in (model.hparams, model._hparams_initial):
        assert saved == {}
class HparamsKwargsContainerModel(BoringModel):
    """Passes its ``**kwargs`` dict as the hparams container.

    ``test_hparams_name_from_container`` expects ``_hparams_name`` to be ``None`` for this model.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters(kwargs)
class HparamsNamespaceContainerModel(BoringModel):
    """Passes its ``config`` argument as the hparams container.

    ``test_hparams_name_from_container`` expects ``_hparams_name == "config"`` for this model.
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters(config)
def test_empty_hparams_container(tmpdir):
    """Saving an empty hparams container leaves ``hparams`` empty."""
    assert not HparamsKwargsContainerModel().hparams
    assert not HparamsNamespaceContainerModel(Namespace()).hparams
def test_hparams_name_from_container(tmpdir):
    """``save_hyperparameters(container)`` records the name of the init argument the container came from."""
    from_kwargs = HparamsKwargsContainerModel(a=1, b=2)
    assert from_kwargs._hparams_name is None

    from_namespace = HparamsNamespaceContainerModel(Namespace(a=1, b=2))
    assert from_namespace._hparams_name == "config"
@dataclass
class DataClassModel(BoringModel):
    """LightningModule declared as a dataclass; the hyperparameters are saved in ``__post_init__``."""

    mandatory: int
    optional: str = "optional"
    ignore_me: bool = False  # excluded from the saved hparams via ``ignore``

    def __post_init__(self):
        # the dataclass-generated ``__init__`` does not call the parent constructor, so it is called here
        super().__init__()
        self.save_hyperparameters(ignore=("ignore_me",))
def test_dataclass_lightning_module(tmpdir):
    """A dataclass LightningModule saves its fields as hparams, except the ignored one."""
    model = DataClassModel(33, optional="cocofruit")
    expected = {"mandatory": 33, "optional": "cocofruit"}
    assert model.hparams == expected
class NoHparamsModel(BoringModel):
    """Model that never calls ``save_hyperparameters``, so it has no hparams of its own."""
class DataModuleWithoutHparams(LightningDataModule):
    """Minimal datamodule that never calls ``save_hyperparameters``."""

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        """Return a ``DataLoader`` over ``RandomDataset(32, 64)`` with ``batch_size=32``."""
        dataset = RandomDataset(32, 64)
        return DataLoader(dataset, batch_size=32)
class DataModuleWithHparams(LightningDataModule):
    """Datamodule that saves the given ``hparams`` container and serves random training data."""

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        # loader over ``RandomDataset(32, 64)`` with ``batch_size=32``
        return DataLoader(RandomDataset(32, 64), batch_size=32)
def _get_mock_logger(tmpdir):
    """Build a ``MagicMock`` standing in for a logger, with a fixed name and version and ``save_dir=tmpdir``.

    ``__iter__`` is deleted from the mock, presumably so the Trainer does not treat it as a collection
    of loggers — TODO confirm.
    """
    logger = mock.MagicMock(name="logger")
    logger.name, logger.save_dir, logger.version = "mock_logger", tmpdir, "0"
    del logger.__iter__
    return logger
@pytest.mark.parametrize("model", (SaveHparamsModel({"arg1": 5, "arg2": "abc"}), NoHparamsModel()))
@pytest.mark.parametrize("data", (DataModuleWithHparams({"data_dir": "foo"}), DataModuleWithoutHparams()))
def test_adding_datamodule_hparams(tmpdir, model, data):
    """The model's and the datamodule's hparams are merged for logging without modifying either one."""
    model_hparams_before = copy.deepcopy(model.hparams_initial)
    data_hparams_before = copy.deepcopy(data.hparams_initial)

    logger = _get_mock_logger(tmpdir)
    Trainer(default_root_dir=tmpdir, max_epochs=1, logger=logger).fit(model, datamodule=data)

    # neither object had its own hparams changed
    assert model_hparams_before == model.hparams
    assert data_hparams_before == data.hparams

    # the union of both is what gets logged; nothing is logged when it is empty
    expected = copy.deepcopy(model_hparams_before)
    expected.update(data_hparams_before)
    if not expected:
        logger.log_hyperparams.assert_not_called()
    else:
        logger.log_hyperparams.assert_called_with(expected)
def test_no_datamodule_for_hparams(tmpdir):
    """Only the model's hparams are logged when the datamodule has none."""
    model = SaveHparamsModel({"arg1": 5, "arg2": "abc"})
    expected = copy.deepcopy(model.hparams_initial)

    datamodule = DataModuleWithoutHparams()
    datamodule.setup("fit")

    logger = _get_mock_logger(tmpdir)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=logger)
    trainer.fit(model, datamodule=datamodule)

    logger.log_hyperparams.assert_called_with(expected)
def test_colliding_hparams(tmpdir):
    """``fit`` fails when model and datamodule hparams share a key (``data_dir``) with different values."""
    model = SaveHparamsModel({"data_dir": "abc", "arg2": "abc"})
    datamodule = DataModuleWithHparams({"data_dir": "foo"})
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    with pytest.raises(MisconfigurationException, match=r"Error while merging hparams:"):
        trainer.fit(model, datamodule=datamodule)
def test_nn_modules_warning_when_saved_as_hparams():
    """Saving ``nn.Module`` arguments as hparams warns; leaving them out of the save does not."""
    warning_pattern = "is an instance of `nn.Module` and is already saved"

    class TorchModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(4, 5)

    class CustomBoringModelWarn(BoringModel):
        def __init__(self, encoder, decoder, other_hparam=7):
            super().__init__()
            self.save_hyperparameters()

    class CustomBoringModelNoWarn(BoringModel):
        def __init__(self, encoder, decoder, other_hparam=7):
            super().__init__()
            self.save_hyperparameters("other_hparam")

    # all arguments, modules included, are saved in signature order
    with pytest.warns(UserWarning, match=warning_pattern):
        warned = CustomBoringModelWarn(encoder=TorchModule(), decoder=TorchModule())
    assert list(warned.hparams) == ["encoder", "decoder", "other_hparam"]

    # only the explicitly requested argument is saved, so no warning is emitted
    with no_warning_call(UserWarning, match=warning_pattern):
        silent = CustomBoringModelNoWarn(encoder=TorchModule(), decoder=TorchModule())
    assert list(silent.hparams) == ["other_hparam"]