# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import os
import pickle
from argparse import Namespace
from dataclasses import dataclass
from enum import Enum
from unittest import mock

import cloudpickle
import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from lightning_utilities.core.imports import RequirementCache
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.utils import no_warning_call

if _OMEGACONF_AVAILABLE:
    from omegaconf import Container, OmegaConf
    from omegaconf.dictconfig import DictConfig


class SaveHparamsModel(BoringModel):
    """Tests that a model can take an object."""

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)


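# Pass-through decorator used below to verify that ``save_hyperparameters()`` still
# resolves the original ``__init__`` arguments when the constructor is wrapped.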
def decorate(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


class SaveHparamsDecoratedModel(BoringModel):
    """Tests that a model with a decorated ``__init__`` can take an object."""

    @decorate
    @decorate
    def __init__(self, hparams, *my_args, **my_kwargs):
        super().__init__()
        self.save_hyperparameters(hparams)


class SaveHparamsDataModule(BoringDataModule):
    """Tests that a datamodule can take an object."""

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)


class SaveHparamsDecoratedDataModule(BoringDataModule):
    """Tests that a datamodule with a decorated ``__init__`` can take an object."""

    @decorate
    @decorate
    def __init__(self, hparams, *my_args, **my_kwargs):
        super().__init__()
        self.save_hyperparameters(hparams)


# -------------------------
# STANDARD TESTS
# -------------------------
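# Shared helper: trains for one epoch, checks that the raw checkpoint dict contains the
# saved hyperparameters, and reloads via ``load_from_checkpoint`` to confirm the values
# and the container type survive the round trip.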
def _run_standard_hparams_test(tmpdir, model, cls, datamodule=None, try_overwrite=False):
    """Tests for the existence of an arg 'test_arg=14'."""
    obj = datamodule if issubclass(cls, LightningDataModule) else model

    hparam_type = type(obj.hparams)
    # test proper property assignments
    assert obj.hparams.test_arg == 14

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
    trainer.fit(model, datamodule=datamodule if issubclass(cls, LightningDataModule) else None)

    # make sure the raw checkpoint saved the properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14

    # verify that model loads correctly
    obj2 = cls.load_from_checkpoint(raw_checkpoint_path)
    assert obj2.hparams.test_arg == 14

    assert isinstance(obj2.hparams, hparam_type)

    if try_overwrite:
        # verify that we can overwrite the property
        obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
        assert obj3.hparams.test_arg == 78

    return raw_checkpoint_path


@pytest.mark.parametrize(
    "cls", [SaveHparamsModel, SaveHparamsDecoratedModel, SaveHparamsDataModule, SaveHparamsDecoratedDataModule]
)
def test_namespace_hparams(tmpdir, cls):
    hparams = Namespace(test_arg=14)

    if issubclass(cls, LightningDataModule):
        model = BoringModel()
        datamodule = cls(hparams=hparams)
    else:
        model = cls(hparams=hparams)
        datamodule = None

    # run standard test suite
    _run_standard_hparams_test(tmpdir, model, cls, datamodule=datamodule)


@pytest.mark.parametrize(
    "cls", [SaveHparamsModel, SaveHparamsDecoratedModel, SaveHparamsDataModule, SaveHparamsDecoratedDataModule]
)
def test_dict_hparams(tmpdir, cls):
    hparams = {"test_arg": 14}
    if issubclass(cls, LightningDataModule):
        model = BoringModel()
        datamodule = cls(hparams=hparams)
    else:
        model = cls(hparams=hparams)
        datamodule = None

    # run standard test suite
    _run_standard_hparams_test(tmpdir, model, cls, datamodule=datamodule)


@RunIf(omegaconf=True)
@pytest.mark.parametrize(
    "cls", [SaveHparamsModel, SaveHparamsDecoratedModel, SaveHparamsDataModule, SaveHparamsDecoratedDataModule]
)
def test_omega_conf_hparams(tmpdir, cls):
    conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)]))
    if issubclass(cls, LightningDataModule):
        model = BoringModel()
        obj = datamodule = cls(hparams=conf)
    else:
        obj = model = cls(hparams=conf)
        datamodule = None

    assert isinstance(obj.hparams, Container)

    # run standard test suite
    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, cls, datamodule=datamodule)
    obj2 = cls.load_from_checkpoint(raw_checkpoint_path)
    assert isinstance(obj2.hparams, Container)

    # config specific tests
    assert obj2.hparams.test_arg == 14
    assert obj2.hparams.mylist[0] == 15.4


def test_explicit_args_hparams(tmpdir):
    """Tests that a model can take explicitly listed args and assign."""

    # define model
    class LocalModel(BoringModel):
        def __init__(self, test_arg, test_arg2):
            super().__init__()
            self.save_hyperparameters("test_arg", "test_arg2")

    model = LocalModel(test_arg=14, test_arg2=90)

    # run standard test suite
    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)

    # config specific tests
    assert model.hparams.test_arg2 == 120


def test_implicit_args_hparams(tmpdir):
    """Tests that a model can take regular args and assign."""

    # define model
    class LocalModel(BoringModel):
        def __init__(self, test_arg, test_arg2):
            super().__init__()
            self.save_hyperparameters()

    model = LocalModel(test_arg=14, test_arg2=90)

    # run standard test suite
    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)

    # config specific tests
    assert model.hparams.test_arg2 == 120


def test_explicit_missing_args_hparams(tmpdir):
    """Tests that a model can take regular args and assign."""

    # define model
    class LocalModel(BoringModel):
        def __init__(self, test_arg, test_arg2):
            super().__init__()
            self.save_hyperparameters("test_arg")

    model = LocalModel(test_arg=14, test_arg2=90)

    # test proper property assignments
    assert model.hparams.test_arg == 14

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14

    # verify that model loads correctly
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
    assert model.hparams.test_arg == 14
    assert "test_arg2" not in model.hparams  # test_arg2 is not registered in class init

    return raw_checkpoint_path


# -------------------------
# SPECIFIC TESTS
# -------------------------


def test_class_nesting():
    class MyModule(LightningModule):
        def forward(self):
            ...

    # make sure PL modules are always nn.Module
    a = MyModule()
    assert isinstance(a, torch.nn.Module)

    def test_outside():
        a = MyModule()
        _ = a.hparams

    class A:
        def test(self):
            a = MyModule()
            _ = a.hparams

        def test2(self):
            test_outside()

    test_outside()
    A().test2()
    A().test()


class CustomBoringModel(BoringModel):
    def __init__(self, batch_size=64):
        super().__init__()
        self.save_hyperparameters()


class SubClassBoringModel(CustomBoringModel):
    any_other_loss = torch.nn.CrossEntropyLoss()

    def __init__(self, *args, subclass_arg=1200, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()


class NonSavingSubClassBoringModel(CustomBoringModel):
    any_other_loss = torch.nn.CrossEntropyLoss()

    def __init__(self, *args, subclass_arg=1200, **kwargs):
        super().__init__(*args, **kwargs)


class SubSubClassBoringModel(SubClassBoringModel):
    pass


class AggSubClassBoringModel(SubClassBoringModel):
    def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()


class UnconventionalArgsBoringModel(CustomBoringModel):
    """A model that has unconventional names for "self", "*args" and "**kwargs"."""

    def __init__(obj, *more_args, other_arg=300, **more_kwargs):
        # intentionally named obj
        super().__init__(*more_args, **more_kwargs)
        obj.save_hyperparameters()


if _OMEGACONF_AVAILABLE:

    class DictConfSubClassBoringModel(SubClassBoringModel):
        def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
            super().__init__(*args, **kwargs)
            self.save_hyperparameters()

else:

    class DictConfSubClassBoringModel:
        ...


@pytest.mark.parametrize(
    "cls",
    [
        CustomBoringModel,
        SubClassBoringModel,
        NonSavingSubClassBoringModel,
        SubSubClassBoringModel,
        AggSubClassBoringModel,
        UnconventionalArgsBoringModel,
        pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
    ],
)
def test_collect_init_arguments(tmpdir, cls):
    """Test that the model automatically saves the arguments passed into the constructor."""
    extra_args = {}
    if cls is AggSubClassBoringModel:
        extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
    elif cls is DictConfSubClassBoringModel:
        extra_args.update(dict_conf=OmegaConf.create(dict(my_param="anything")))

    model = cls(**extra_args)
    assert model.hparams.batch_size == 64
    model = cls(batch_size=179, **extra_args)
    assert model.hparams.batch_size == 179

    if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel)):
        assert model.hparams.subclass_arg == 1200

    if isinstance(model, AggSubClassBoringModel):
        assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)

    # verify that the checkpoint saved the correct values
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)

    raw_checkpoint_path = _raw_checkpoint_path(trainer)

    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["batch_size"] == 179

    # verify that model loads correctly
    model = cls.load_from_checkpoint(raw_checkpoint_path)
    assert model.hparams.batch_size == 179

    if isinstance(model, AggSubClassBoringModel):
        assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)

    if isinstance(model, DictConfSubClassBoringModel):
        assert isinstance(model.hparams.dict_conf, Container)
        assert model.hparams.dict_conf["my_param"] == "anything"

    # verify that we can overwrite whatever we want
    model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
    assert model.hparams.batch_size == 99


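# Helper: returns the path of the first ``.ckpt`` file found in the trainer's checkpoint callback directory.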
def _raw_checkpoint_path(trainer) -> str:
    raw_checkpoint_paths = os.listdir(trainer.checkpoint_callback.dirpath)
    raw_checkpoint_paths = [x for x in raw_checkpoint_paths if ".ckpt" in x]
    assert raw_checkpoint_paths
    raw_checkpoint_path = raw_checkpoint_paths[0]
    raw_checkpoint_path = os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path)
    return raw_checkpoint_path


@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule))
def test_save_hyperparameters_under_composition(base_class):
    """Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
    collected."""

    class ChildInComposition(base_class):
        def __init__(self, same_arg):
            super().__init__()
            self.save_hyperparameters()

    class NotPLSubclass:  # intentionally not subclassing LightningModule/LightningDataModule
        def __init__(self, same_arg="parent_default", other_arg="other"):
            self.child = ChildInComposition(same_arg="cocofruit")

    parent = NotPLSubclass()
    assert parent.child.hparams == dict(same_arg="cocofruit")


class LocalVariableModelSuperLast(BoringModel):
    """This model has the super().__init__() call at the end."""

    def __init__(self, arg1, arg2, *args, **kwargs):
        self.argument1 = arg1  # arg2 intentionally not set
        arg1 = "overwritten"
        local_var = 1234  # noqa: F841
        super().__init__(*args, **kwargs)  # this is intentionally here at the end


class LocalVariableModelSuperFirst(BoringModel):
    """This model has the save_hyperparameters() call at the end."""

    def __init__(self, arg1, arg2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.argument1 = arg1  # arg2 intentionally not set
        arg1 = "overwritten"
        local_var = 1234  # noqa: F841
        self.save_hyperparameters()  # this is intentionally here at the end


@pytest.mark.parametrize(
    "cls",
    [
        LocalVariableModelSuperFirst,
        # LocalVariableModelSuperLast,
    ],
)
def test_collect_init_arguments_with_local_vars(cls):
    """Tests that only the arguments are collected and not local variables."""
    model = cls(arg1=1, arg2=2)
    assert "local_var" not in model.hparams
    assert model.hparams["arg1"] == "overwritten"
    assert model.hparams["arg2"] == 2


class AnotherArgModel(BoringModel):
    def __init__(self, arg1):
        super().__init__()
        self.save_hyperparameters(arg1)


class OtherArgsModel(BoringModel):
    def __init__(self, arg1, arg2):
        super().__init__()
        self.save_hyperparameters(arg1, arg2)


@pytest.mark.parametrize(
    "cls,config", [(AnotherArgModel, dict(arg1=42)), (OtherArgsModel, dict(arg1=3.14, arg2="abc"))]
)
def test_single_config_models_fail(tmpdir, cls, config):
    """Test fail on passing unsupported config type."""
    with pytest.raises(ValueError):
        _ = cls(**config)


@pytest.mark.parametrize("past_key", ["module_arguments"])
def test_load_past_checkpoint(tmpdir, past_key):
    model = CustomBoringModel()

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    raw_checkpoint["hparams_type"] = "Namespace"
    raw_checkpoint[past_key]["batch_size"] = -17
    del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    # save back the checkpoint
    torch.save(raw_checkpoint, raw_checkpoint_path)

    # verify that model loads correctly
    model2 = CustomBoringModel.load_from_checkpoint(raw_checkpoint_path)
    assert model2.hparams.batch_size == -17


def test_hparams_pickle(tmpdir):
    ad = AttributeDict({"key1": 1, "key2": "abc"})
    pkl = pickle.dumps(ad)
    assert ad == pickle.loads(pkl)
    pkl = cloudpickle.dumps(ad)
    assert ad == pickle.loads(pkl)


class UnpickleableArgsBoringModel(BoringModel):
    """A model that has an attribute that cannot be pickled."""

    def __init__(self, foo="bar", pickle_me=(lambda x: x + 1), **kwargs):
        super().__init__(**kwargs)
        assert not is_picklable(pickle_me)
        self.save_hyperparameters()


def test_hparams_pickle_warning(tmpdir):
    model = UnpickleableArgsBoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1)
    with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"):
        trainer.fit(model)
    assert "pickle_me" not in model.hparams


def test_hparams_save_yaml(tmpdir):
    class Options(str, Enum):
        option1name = "option1val"
        option2name = "option2val"
        option3name = "option3val"

    hparams = dict(
        batch_size=32,
        learning_rate=0.001,
        data_root="./any/path/here",
        nested=dict(any_num=123, anystr="abcd"),
        switch=Options.option3name,
    )
    path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")

    def _compare_params(loaded_params, default_params: dict):
        assert isinstance(loaded_params, (dict, DictConfig))
        assert loaded_params.keys() == default_params.keys()
        for k, v in default_params.items():
            if isinstance(v, Enum):
                assert v.name == loaded_params[k]
            else:
                assert v == loaded_params[k]

    save_hparams_to_yaml(path_yaml, hparams)
    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

    save_hparams_to_yaml(path_yaml, Namespace(**hparams))
    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

    save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

    if _OMEGACONF_AVAILABLE:
        save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
        _compare_params(load_hparams_from_yaml(path_yaml), hparams)


class NoArgsSubClassBoringModel(CustomBoringModel):
    def __init__(self):
        super().__init__()


@pytest.mark.parametrize("cls", [BoringModel, NoArgsSubClassBoringModel])
def test_model_nohparams_train_test(tmpdir, cls):
    """Test models that do not take any argument in init."""

    model = cls()
    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)

    train_loader = DataLoader(RandomDataset(32, 64), batch_size=32)
    trainer.fit(model, train_loader)

    test_loader = DataLoader(RandomDataset(32, 64), batch_size=32)
    trainer.test(dataloaders=test_loader)


def test_model_ignores_non_exist_kwargument(tmpdir):
    """Test that the model takes only valid class arguments."""

    class LocalModel(BoringModel):
        def __init__(self, batch_size=15):
            super().__init__()
            self.save_hyperparameters()

    model = LocalModel()
    assert model.hparams.batch_size == 15

    # verify that the checkpoint saved the correct values
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    # verify that we can overwrite whatever we want
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, non_exist_kwarg=99)
    assert "non_exist_kwarg" not in model.hparams


class SuperClassPositionalArgs(BoringModel):
    def __init__(self, hparams):
        super().__init__()
        self._hparams = hparams  # pretend BoringModel did not call self.save_hyperparameters()


class SubClassVarArgs(SuperClassPositionalArgs):
    """Loading this model should accept hparams and init in the super class."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)


def test_args(tmpdir):
    """Test for inheritance: super class takes positional arg, subclass takes varargs."""
    hparams = dict(test=1)
    model = SubClassVarArgs(hparams)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'test'"):
        SubClassVarArgs.load_from_checkpoint(raw_checkpoint_path)


class RuntimeParamChangeModelSaving(BoringModel):
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()


@pytest.mark.parametrize("cls", [RuntimeParamChangeModelSaving])
def test_init_arg_with_runtime_change(tmpdir, cls):
    """Test that we save/export only the initial hparams, no other runtime change allowed."""
    model = cls(running_arg=123)
    assert model.hparams.running_arg == 123
    model.hparams.running_arg = -1
    assert model.hparams.running_arg == -1

    trainer = Trainer(
        default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1
    )
    trainer.fit(model)

    path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE)
    hparams = load_hparams_from_yaml(path_yaml)
    assert hparams.get("running_arg") == 123


class UnsafeParamModel(BoringModel):
    def __init__(self, my_path, any_param=123):
        super().__init__()
        self.save_hyperparameters()


def test_model_with_fsspec_as_parameter(tmpdir):
    model = UnsafeParamModel(LocalFileSystem(tmpdir))
    trainer = Trainer(
        default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1
    )
    trainer.fit(model)
    trainer.test()


@pytest.mark.skipif(RequirementCache("hydra-core<1.1"), reason="Requires Hydra's Compose API")
def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir):
    """This test relies on configuration saved under tests/models/conf/config.yaml."""
    from hydra import compose, initialize

    class TestHydraModel(BoringModel):
        def __init__(self, args_0, args_1, args_2, kwarg_1=None):
            self.save_hyperparameters()
            assert self.hparams.args_0.log == "Something"
            assert self.hparams.args_1["cfg"].log == "Something"
            assert self.hparams.args_2[0].log == "Something"
            assert self.hparams.kwarg_1["cfg"][0].log == "Something"
            super().__init__()

    with initialize(config_path="conf"):
        args_0 = compose(config_name="config")
        args_1 = {"cfg": compose(config_name="config")}
        args_2 = [compose(config_name="config")]
        kwarg_1 = {"cfg": [compose(config_name="config")]}
        model = TestHydraModel(args_0, args_1, args_2, kwarg_1=kwarg_1)
        epochs = 2
        checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1)
        trainer = Trainer(
            default_root_dir=tmpdir,
            callbacks=[checkpoint_callback],
            limit_train_batches=10,
            limit_val_batches=10,
            max_epochs=epochs,
            logger=False,
        )
        trainer.fit(model)
        _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)


@pytest.mark.parametrize("ignore", ("arg2", ("arg2", "arg3")))
def test_ignore_args_list_hparams(tmpdir, ignore):
    """Tests that args can be ignored in save_hyperparameters."""

    class LocalModel(BoringModel):
        def __init__(self, arg1, arg2, arg3):
            super().__init__()
            self.save_hyperparameters(ignore=ignore)

    model = LocalModel(arg1=14, arg2=90, arg3=50)

    # test proper property assignments
    assert model.hparams.arg1 == 14
    for arg in ignore:
        assert arg not in model.hparams

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["arg1"] == 14

    # verify that model loads correctly
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, arg2=123, arg3=100)
    assert model.hparams.arg1 == 14
    for arg in ignore:
        assert arg not in model.hparams


class IgnoreAllParametersModel(BoringModel):
    def __init__(self, arg1, arg2, arg3):
        super().__init__()
        self.save_hyperparameters(ignore=("arg1", "arg2", "arg3"))


class NoParametersModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()


@pytest.mark.parametrize(
    "model",
    (
        IgnoreAllParametersModel(arg1=14, arg2=90, arg3=50),
        NoParametersModel(),
    ),
)
def test_save_no_parameters(model):
    """Test that calling save_hyperparameters works if no parameters need saving."""
    assert model.hparams == {}
    assert model._hparams_initial == {}


class HparamsKwargsContainerModel(BoringModel):
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters(kwargs)


class HparamsNamespaceContainerModel(BoringModel):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters(config)


def test_empty_hparams_container(tmpdir):
    """Test that save_hyperparameters() is a no-op when saving an empty hparams container."""
    model = HparamsKwargsContainerModel()
    assert not model.hparams
    model = HparamsNamespaceContainerModel(Namespace())
    assert not model.hparams


def test_hparams_name_from_container(tmpdir):
    """Test that save_hyperparameters(container) captures the name of the argument correctly."""
    model = HparamsKwargsContainerModel(a=1, b=2)
    assert model._hparams_name is None
    model = HparamsNamespaceContainerModel(Namespace(a=1, b=2))
    assert model._hparams_name == "config"


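# ``@dataclass`` generates ``__init__``, so ``save_hyperparameters()`` is called from ``__post_init__``.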
@dataclass
class DataClassModel(BoringModel):
    mandatory: int
    optional: str = "optional"
    ignore_me: bool = False

    def __post_init__(self):
        super().__init__()
        self.save_hyperparameters(ignore=("ignore_me",))


def test_dataclass_lightning_module(tmpdir):
    """Test that save_hyperparameters() works with a LightningModule as a dataclass."""
    model = DataClassModel(33, optional="cocofruit")
    assert model.hparams == dict(mandatory=33, optional="cocofruit")


class NoHparamsModel(BoringModel):
    """Tests a model without hparams."""


class DataModuleWithoutHparams(LightningDataModule):
    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        return DataLoader(RandomDataset(32, 64), batch_size=32)


class DataModuleWithHparams(LightningDataModule):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        return DataLoader(RandomDataset(32, 64), batch_size=32)


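# Helper: a MagicMock logger exposing only the attributes the Trainer reads (name, save_dir,
# version); ``__iter__`` is deleted so the mock does not look like an iterable of loggers.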
def _get_mock_logger(tmpdir):
    mock_logger = mock.MagicMock(name="logger")
    mock_logger.name = "mock_logger"
    mock_logger.save_dir = tmpdir
    mock_logger.version = "0"
    del mock_logger.__iter__
    return mock_logger


@pytest.mark.parametrize("model", (SaveHparamsModel({"arg1": 5, "arg2": "abc"}), NoHparamsModel()))
@pytest.mark.parametrize("data", (DataModuleWithHparams({"data_dir": "foo"}), DataModuleWithoutHparams()))
def test_adding_datamodule_hparams(tmpdir, model, data):
    """Test that hparams from datamodule and model are logged."""
    org_model_hparams = copy.deepcopy(model.hparams_initial)
    org_data_hparams = copy.deepcopy(data.hparams_initial)

    mock_logger = _get_mock_logger(tmpdir)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger)
    trainer.fit(model, datamodule=data)

    # Hparams of model and data were not modified
    assert org_model_hparams == model.hparams
    assert org_data_hparams == data.hparams

    # Merged hparams were logged
    merged_hparams = copy.deepcopy(org_model_hparams)
    merged_hparams.update(org_data_hparams)
    if merged_hparams:
        mock_logger.log_hyperparams.assert_called_with(merged_hparams)
    else:
        mock_logger.log_hyperparams.assert_not_called()


def test_no_datamodule_for_hparams(tmpdir):
    """Test that the model's hparams are logged if no datamodule is used."""
    model = SaveHparamsModel({"arg1": 5, "arg2": "abc"})
    org_model_hparams = copy.deepcopy(model.hparams_initial)
    data = DataModuleWithoutHparams()
    data.setup("fit")

    mock_logger = _get_mock_logger(tmpdir)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger)
    trainer.fit(model, datamodule=data)

    # Merged hparams were logged
    mock_logger.log_hyperparams.assert_called_with(org_model_hparams)


def test_colliding_hparams(tmpdir):
    model = SaveHparamsModel({"data_dir": "abc", "arg2": "abc"})
    data = DataModuleWithHparams({"data_dir": "foo"})

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    with pytest.raises(MisconfigurationException, match=r"Error while merging hparams:"):
        trainer.fit(model, datamodule=data)


def test_nn_modules_warning_when_saved_as_hparams():
    class TorchModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(4, 5)

    class CustomBoringModelWarn(BoringModel):
        def __init__(self, encoder, decoder, other_hparam=7):
            super().__init__()
            self.save_hyperparameters()

    with pytest.warns(UserWarning, match="is an instance of `nn.Module` and is already saved"):
        model = CustomBoringModelWarn(encoder=TorchModule(), decoder=TorchModule())

    assert list(model.hparams) == ["encoder", "decoder", "other_hparam"]

    class CustomBoringModelNoWarn(BoringModel):
        def __init__(self, encoder, decoder, other_hparam=7):
            super().__init__()
            self.save_hyperparameters("other_hparam")

    with no_warning_call(UserWarning, match="is an instance of `nn.Module` and is already saved"):
        model = CustomBoringModelNoWarn(encoder=TorchModule(), decoder=TorchModule())

    assert list(model.hparams) == ["other_hparam"]