import os
import pickle
from argparse import Namespace

import cloudpickle
import pytest
import torch
from omegaconf import OmegaConf, Container
from torch.nn import functional as F
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
from pytorch_lightning.utilities import AttributeDict, is_picklable
from tests.base import EvalModelTemplate, TrialMNIST


class SaveHparamsModel(EvalModelTemplate):
    """ Tests that a model can take an object """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)


class AssignHparamsModel(EvalModelTemplate):
    """ Tests that a model can take an object with explicit setter """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams


# -------------------------
# STANDARD TESTS
# -------------------------
def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
    """
    Tests for the existence of an arg 'test_arg=14'

    Trains ``model`` briefly, checks that the raw checkpoint stores ``test_arg == 14``
    and that ``cls.load_from_checkpoint`` restores it with the same hparams type.
    When ``try_overwrite`` is set, also checks that ``test_arg`` can be overridden at load time.

    Returns the path of the checkpoint that was inspected.
    """
    hparam_type = type(model.hparams)
    # test proper property assignments
    assert model.hparams.test_arg == 14

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14

    # verify that model loads correctly
    model2 = cls.load_from_checkpoint(raw_checkpoint_path)
    assert model2.hparams.test_arg == 14

    assert isinstance(model2.hparams, hparam_type)

    if try_overwrite:
        # verify that we can overwrite the property
        model3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
        assert model3.hparams.test_arg == 78

    return raw_checkpoint_path


@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
def test_namespace_hparams(tmpdir, cls):
    """ Tests that hparams given as an ``argparse.Namespace`` pass the standard suite """
    # init model
    model = cls(hparams=Namespace(test_arg=14))

    # run standard test suite
    _run_standard_hparams_test(tmpdir, model, cls)


@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
def test_dict_hparams(tmpdir, cls):
    """ Tests that hparams given as a plain ``dict`` pass the standard suite """
    # init model
    model = cls(hparams={'test_arg': 14})

    # run standard test suite
    _run_standard_hparams_test(tmpdir, model, cls)


@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
def test_omega_conf_hparams(tmpdir, cls):
    """ Tests that hparams given as an OmegaConf container stay a ``Container`` after reload """
    # init model
    conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)]))
    model = cls(hparams=conf)
    assert isinstance(model.hparams, Container)

    # run standard test suite
    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, cls)
    model2 = cls.load_from_checkpoint(raw_checkpoint_path)
    assert isinstance(model2.hparams, Container)

    # config specific tests
    assert model2.hparams.test_arg == 14
    assert model2.hparams.mylist[0] == 15.4


def test_explicit_args_hparams(tmpdir):
    """
    Tests that a model can take explicitly named args and assign
    """

    # define model
    class LocalModel(EvalModelTemplate):
        def __init__(self, test_arg, test_arg2):
            super().__init__()
            self.save_hyperparameters('test_arg', 'test_arg2')

    model = LocalModel(test_arg=14, test_arg2=90)

    # run standard test suite
    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)

    # config specific tests
    assert model.hparams.test_arg2 == 120


def test_implicit_args_hparams(tmpdir):
    """
    Tests that a model can take implicit args (``save_hyperparameters()`` without names) and assign
    """

    # define model
    class LocalModel(EvalModelTemplate):
        def __init__(self, test_arg, test_arg2):
            super().__init__()
            self.save_hyperparameters()

    model = LocalModel(test_arg=14, test_arg2=90)

    # run standard test suite
    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)

    # config specific tests
    assert model.hparams.test_arg2 == 120


def test_explicit_missing_args_hparams(tmpdir):
    """
    Tests that only the explicitly registered args end up in hparams
    """

    # define model
    class LocalModel(EvalModelTemplate):
        def __init__(self, test_arg, test_arg2):
            super().__init__()
            # only 'test_arg' is registered; 'test_arg2' is deliberately left out
            self.save_hyperparameters('test_arg')

    model = LocalModel(test_arg=14, test_arg2=90)

    # test proper property assignments
    assert model.hparams.test_arg == 14

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14

    # verify that model loads correctly
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
    assert model.hparams.test_arg == 14
    assert 'test_arg2' not in model.hparams  # test_arg2 is not registered in class init

    return raw_checkpoint_path


# -------------------------
# SPECIFIC TESTS
# -------------------------

def test_class_nesting():
    """ Tests that ``hparams`` can be accessed on a module created from nested functions and classes """

    class MyModule(LightningModule):
        def forward(self):
            ...

    # make sure PL modules are always nn.Module
    a = MyModule()
    assert isinstance(a, torch.nn.Module)

    def test_outside():
        a = MyModule()
        _ = a.hparams

    class A:
        def test(self):
            a = MyModule()
            _ = a.hparams

        def test2(self):
            test_outside()

    test_outside()
    A().test2()
    A().test()


class SubClassEvalModel(EvalModelTemplate):
    """ A subclass that adds one keyword argument and a class-level loss attribute """
    any_other_loss = torch.nn.CrossEntropyLoss()

    def __init__(self, *args, subclass_arg=1200, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()


class SubSubClassEvalModel(SubClassEvalModel):
    """ A second-level subclass that defines no ``__init__`` of its own """
    pass


class AggSubClassEvalModel(SubClassEvalModel):
    """ A subclass whose extra argument is a loss module instance """

    def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()


class UnconventionalArgsEvalModel(EvalModelTemplate):
    """ A model that has unconventional names for "self", "*args" and "**kwargs". """

    def __init__(obj, *more_args, other_arg=300, **more_kwargs):
        # intentionally named obj
        super().__init__(*more_args, **more_kwargs)
        obj.save_hyperparameters()


class DictConfSubClassEvalModel(SubClassEvalModel):
    """ A subclass whose extra argument is an OmegaConf container """

    def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()


@pytest.mark.parametrize("cls", [
    EvalModelTemplate,
    SubClassEvalModel,
    SubSubClassEvalModel,
    AggSubClassEvalModel,
    UnconventionalArgsEvalModel,
    DictConfSubClassEvalModel,
])
def test_collect_init_arguments(tmpdir, cls):
    """ Test that the model automatically saves the arguments passed into the constructor """
    extra_args = {}
    if cls is AggSubClassEvalModel:
        extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
    elif cls is DictConfSubClassEvalModel:
        extra_args.update(dict_conf=OmegaConf.create(dict(my_param='anything')))

    model = cls(**extra_args)
    assert model.hparams.batch_size == 32
    model = cls(batch_size=179, **extra_args)
    assert model.hparams.batch_size == 179

    if isinstance(model, SubClassEvalModel):
        assert model.hparams.subclass_arg == 1200

    if isinstance(model, AggSubClassEvalModel):
        assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)

    # verify that the checkpoint saved the correct values
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)
    raw_checkpoint_path = _raw_checkpoint_path(trainer)

    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['batch_size'] == 179

    # verify that model loads correctly
    model = cls.load_from_checkpoint(raw_checkpoint_path)
    assert model.hparams.batch_size == 179

    if isinstance(model, AggSubClassEvalModel):
        assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)

    if isinstance(model, DictConfSubClassEvalModel):
        assert isinstance(model.hparams.dict_conf, Container)
        assert model.hparams.dict_conf['my_param'] == 'anything'

    # verify that we can overwrite whatever we want
    model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
    assert model.hparams.batch_size == 99


def _raw_checkpoint_path(trainer) -> str:
    """
    Return the full path of the first '.ckpt' entry listed in the trainer's checkpoint directory.

    Asserts that at least one such entry exists.
    """
    raw_checkpoint_paths = os.listdir(trainer.checkpoint_callback.dirpath)
    raw_checkpoint_paths = [x for x in raw_checkpoint_paths if '.ckpt' in x]
    assert raw_checkpoint_paths
    raw_checkpoint_path = raw_checkpoint_paths[0]
    raw_checkpoint_path = os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path)
    return raw_checkpoint_path


class LocalVariableModelSuperLast(EvalModelTemplate):
    """ This model has the super().__init__() call at the end. """

    def __init__(self, arg1, arg2, *args, **kwargs):
        self.argument1 = arg1  # arg2 intentionally not set
        arg1 = 'overwritten'
        local_var = 1234
        super().__init__(*args, **kwargs)  # this is intentionally here at the end


class LocalVariableModelSuperFirst(EvalModelTemplate):
    """ This model has the save_hyperparameters() call at the end. """

    def __init__(self, arg1, arg2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.argument1 = arg1  # arg2 intentionally not set
        arg1 = 'overwritten'
        local_var = 1234
        self.save_hyperparameters()  # this is intentionally here at the end


@pytest.mark.parametrize("cls", [
    LocalVariableModelSuperFirst,
    # LocalVariableModelSuperLast,
])
def test_collect_init_arguments_with_local_vars(cls):
    """ Tests that only the arguments are collected and not local variables. """
    model = cls(arg1=1, arg2=2)
    assert 'local_var' not in model.hparams
    # the value collected is the one bound to the argument name when save_hyperparameters() runs
    assert model.hparams['arg1'] == 'overwritten'
    assert model.hparams['arg2'] == 2


# @pytest.mark.parametrize("cls,config", [
#     (SaveHparamsModel, Namespace(my_arg=42)),
#     (SaveHparamsModel, dict(my_arg=42)),
#     (SaveHparamsModel, OmegaConf.create(dict(my_arg=42))),
#     (AssignHparamsModel, Namespace(my_arg=42)),
#     (AssignHparamsModel, dict(my_arg=42)),
#     (AssignHparamsModel, OmegaConf.create(dict(my_arg=42))),
# ])
# def test_single_config_models(tmpdir, cls, config):
#     """ Test that the model automatically saves the arguments passed into the constructor """
#     model = cls(config)
#
#     # no matter how you do it, it should be assigned
#     assert model.hparams.my_arg == 42
#
#     # verify that the checkpoint saved the correct values
#     trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
#     trainer.fit(model)
#
#     # verify that model loads correctly
#     raw_checkpoint_path = _raw_checkpoint_path(trainer)
#     model = cls.load_from_checkpoint(raw_checkpoint_path)
#     assert model.hparams.my_arg == 42


class AnotherArgModel(EvalModelTemplate):
    """ Passes the argument *value* (not its name) to ``save_hyperparameters`` """

    def __init__(self, arg1):
        super().__init__()
        self.save_hyperparameters(arg1)


class OtherArgsModel(EvalModelTemplate):
    """ Passes two argument *values* (not their names) to ``save_hyperparameters`` """

    def __init__(self, arg1, arg2):
        super().__init__()
        self.save_hyperparameters(arg1, arg2)


@pytest.mark.parametrize("cls,config", [
    (AnotherArgModel, dict(arg1=42)),
    (OtherArgsModel, dict(arg1=3.14, arg2='abc')),
])
def test_single_config_models_fail(tmpdir, cls, config):
    """ Test fail on passing unsupported config type. """
    with pytest.raises(ValueError):
        _ = cls(**config)


@pytest.mark.parametrize("past_key", ['module_arguments'])
def test_load_past_checkpoint(tmpdir, past_key):
    """ Tests loading a checkpoint whose hparams are stored under a past key name """
    model = EvalModelTemplate()

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    # move the hparams under the past key and change a value so the reload is detectable
    raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    raw_checkpoint['hparams_type'] = 'Namespace'
    raw_checkpoint[past_key]['batch_size'] = -17
    del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    # save back the checkpoint
    torch.save(raw_checkpoint, raw_checkpoint_path)

    # verify that model loads correctly
    model2 = EvalModelTemplate.load_from_checkpoint(raw_checkpoint_path)
    assert model2.hparams.batch_size == -17


def test_hparams_pickle(tmpdir):
    """ Tests that an ``AttributeDict`` round-trips through pickle and cloudpickle """
    ad = AttributeDict({'key1': 1, 'key2': 'abc'})
    pkl = pickle.dumps(ad)
    assert ad == pickle.loads(pkl)
    pkl = cloudpickle.dumps(ad)
    assert ad == pickle.loads(pkl)


class UnpickleableArgsEvalModel(EvalModelTemplate):
    """ A model that has an attribute that cannot be pickled. """

    def __init__(self, foo='bar', pickle_me=(lambda x: x + 1), **kwargs):
        super().__init__(**kwargs)
        assert not is_picklable(pickle_me)
        self.save_hyperparameters()


def test_hparams_pickle_warning(tmpdir):
    """ Tests that fitting warns about an unpicklable hparam and drops it from hparams """
    model = UnpickleableArgsEvalModel()
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1)
    with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"):
        trainer.fit(model)
    assert 'pickle_me' not in model.hparams


def test_hparams_save_yaml(tmpdir):
    """ Tests that hparams saved to YAML from several container types load back as the same dict """
    hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
                   nasted=dict(any_num=123, anystr='abcd'))
    path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')

    save_hparams_to_yaml(path_yaml, hparams)
    assert load_hparams_from_yaml(path_yaml) == hparams

    save_hparams_to_yaml(path_yaml, Namespace(**hparams))
    assert load_hparams_from_yaml(path_yaml) == hparams

    save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
    assert load_hparams_from_yaml(path_yaml) == hparams

    save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
    assert load_hparams_from_yaml(path_yaml) == hparams


class NoArgsSubClassEvalModel(EvalModelTemplate):
    """ A template subclass whose ``__init__`` takes no arguments """

    def __init__(self):
        super().__init__()


class SimpleNoArgsModel(LightningModule):
    """ A minimal single-linear-layer module whose ``__init__`` takes no arguments """

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # flatten everything after the batch dimension before the linear layer
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def test_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


@pytest.mark.parametrize("cls", [
    SimpleNoArgsModel,
    NoArgsSubClassEvalModel,
])
def test_model_nohparams_train_test(tmpdir, cls):
    """Test models that do not take any argument in init."""

    model = cls()
    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
    )

    train_loader = DataLoader(TrialMNIST(os.getcwd(), train=True, download=True), batch_size=32)
    trainer.fit(model, train_loader)

    test_loader = DataLoader(TrialMNIST(os.getcwd(), train=False, download=True), batch_size=32)
    trainer.test(test_dataloaders=test_loader)


def test_model_ignores_non_exist_kwargument(tmpdir):
    """Test that the model takes only valid class arguments."""

    class LocalModel(EvalModelTemplate):
        def __init__(self, batch_size=15):
            super().__init__(batch_size=batch_size)
            self.save_hyperparameters()

    model = LocalModel()
    assert model.hparams.batch_size == 15

    # verify that the checkpoint saved the correct values
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    # verify that we can overwrite whatever we want
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, non_exist_kwarg=99)
    assert 'non_exist_kwarg' not in model.hparams


class SuperClassPositionalArgs(EvalModelTemplate):
    """ A super class that takes ``hparams`` as a positional argument and assigns it via the setter """

    def __init__(self, hparams):
        super().__init__()
        self._hparams = None  # pretend EvalModelTemplate did not call self.save_hyperparameters()
        self.hparams = hparams


class SubClassVarArgs(SuperClassPositionalArgs):
    """ Loading this model should accept hparams and init in the super class """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


def test_args(tmpdir):
    """ Test for inheritance: super class takes positional arg, subclass takes varargs. """
    hparams = dict(test=1)
    model = SubClassVarArgs(hparams)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    # raw string: '\(' is a regex escape, not a valid Python string escape
    with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'test'"):
        SubClassVarArgs.load_from_checkpoint(raw_checkpoint_path)