diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index ed299bb8a8..4970644470 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -21,10 +21,17 @@ except ImportError:
 else:
     ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (Container, )
 
+# keys for older checkpoint formats, oldest first
+CHECKPOINT_PAST_HPARAMS_KEYS = (
+    'hparams',
+    'module_arguments',  # used in 0.7.6
+)
+
 
 class ModelIO(object):
-    CHECKPOINT_KEY_HYPER_PARAMS = 'hyper_parameters'
-    CHECKPOINT_NAME_HYPER_PARAMS = 'hparams_name'
+    CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters'
+    CHECKPOINT_HYPER_PARAMS_NAME = 'hparams_name'
+    CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type'
 
     @classmethod
     def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
@@ -153,10 +160,13 @@ class ModelIO(object):
             hparams['on_gpu'] = False
 
             # overwrite hparams by the given file
-            checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS] = hparams
+            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
 
-        # override the module_arguments with values that were passed in
-        checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS].update(kwargs)
+        # older checkpoints may be missing the new key, so add it
+        if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
+            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
+        # override the hparams with values that were passed in
+        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
 
         model = cls._load_model_state(checkpoint, *args, **kwargs)
         return model
@@ -164,10 +174,15 @@ class ModelIO(object):
     @classmethod
     def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
         # pass in the values we saved automatically
-        if cls.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
-            # todo add some back compatibility
-            model_args = checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS]
-            args_name = checkpoint.get(cls.CHECKPOINT_NAME_HYPER_PARAMS)
+        if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+            model_args = {}
+            # for backward compatibility, merge older keys first and the current key last
+            for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
+                if hparam_key in checkpoint:
+                    model_args.update(checkpoint[hparam_key])
+            if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
+                model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
+            args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
             init_args_name = inspect.signature(cls).parameters.keys()
             if args_name == 'kwargs':
                 cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index 12ce47398d..1c1f8161c0 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -4,7 +4,6 @@ TensorBoard
 """
 
 import os
-import yaml
 from argparse import Namespace
 from typing import Optional, Dict, Union, Any
 from warnings import warn
@@ -18,6 +17,11 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only
 
+try:
+    from omegaconf import Container
+except ImportError:
+    Container = None
+
 
 class TensorBoardLogger(LightningLoggerBase):
     r"""
@@ -152,7 +156,14 @@ class TensorBoardLogger(LightningLoggerBase):
         hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
 
         # save the metatags file
-        save_hparams_to_yaml(hparams_file, self.hparams)
+        if Container is not None:
+            if isinstance(self.hparams, Container):
+                from omegaconf import OmegaConf
+                OmegaConf.save(self.hparams, hparams_file, resolve=True)
+            else:
+                save_hparams_to_yaml(hparams_file, self.hparams)
+        else:
+            save_hparams_to_yaml(hparams_file, self.hparams)
 
     @rank_zero_only
     def finalize(self, status: str) -> None:
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 6f4e85d5b2..955f6d768d 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -87,12 +87,12 @@ import os
 import re
 import signal
 from abc import ABC
-from argparse import Namespace
 from subprocess import call
 
 import torch
 import torch.distributed as torch_distrib
 
+import pytorch_lightning
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
@@ -100,7 +100,7 @@ from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
     LightningDataParallel,
 )
-from pytorch_lightning.utilities import rank_zero_warn, parsing
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.io import load as pl_load
 
 try:
@@ -119,6 +119,11 @@ except ImportError:
 else:
     HOROVOD_AVAILABLE = True
 
+try:
+    from omegaconf import Container
+except ImportError:
+    Container = None
+
 
 class TrainerIOMixin(ABC):
 
@@ -267,8 +272,8 @@
             try:
                 self._atomic_save(checkpoint, filepath)
             except AttributeError as err:
-                if LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
-                    del checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]
+                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
                 rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.'
                                f' An attribute is not picklable {err}')
                 self._atomic_save(checkpoint, filepath)
@@ -320,6 +325,7 @@
         checkpoint = {
             'epoch': self.current_epoch + 1,
             'global_step': self.global_step + 1,
+            'pytorch-ligthning_version': pytorch_lightning.__version__,
         }
 
         if not weights_only:
@@ -356,10 +362,12 @@
 
         if model.hparams:
             if hasattr(model, '_hparams_name'):
-                checkpoint[LightningModule.CHECKPOINT_NAME_HYPER_PARAMS] = model._hparams_name
+                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
             # add arguments to the checkpoint
-            # todo: add some recursion in case of OmegaConf
-            checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS] = dict(model.hparams)
+            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
+            if Container is not None:
+                if isinstance(model.hparams, Container):
+                    checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
 
         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
@@ -473,8 +481,8 @@
         try:
             self._atomic_save(checkpoint, filepath)
         except AttributeError as err:
-            if LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
-                del checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]
+            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+                del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
             rank_zero_warn('warning, `module_arguments` dropped from checkpoint.'
                            f' An attribute is not picklable {err}')
             self._atomic_save(checkpoint, filepath)
diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index 74faec8fba..eef38c4656 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -30,31 +30,34 @@ class AssignHparamsModel(EvalModelTemplate):
 # -------------------------
 # STANDARD TESTS
 # -------------------------
-def _run_standard_hparams_test(tmpdir, model, cls):
+def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
     """
     Tests for the existence of an arg 'test_arg=14'
     """
+    hparam_type = type(model.hparams)
 
     # test proper property assignments
     assert model.hparams.test_arg == 14
 
     # verify we can train
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_pct=0.5)
     trainer.fit(model)
 
     # make sure the raw checkpoint saved the properties
     raw_checkpoint_path = _raw_checkpoint_path(trainer)
     raw_checkpoint = torch.load(raw_checkpoint_path)
-    assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in raw_checkpoint
-    assert raw_checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]['test_arg'] == 14
+    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
+    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14
 
     # verify that model loads correctly
-    model = cls.load_from_checkpoint(raw_checkpoint_path)
-    assert model.hparams.test_arg == 14
+    model2 = cls.load_from_checkpoint(raw_checkpoint_path)
+    assert model2.hparams.test_arg == 14
 
-    # todo
-    # verify that we can overwrite the property
-    # model = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
-    # assert model.hparams.test_arg == 78
+    assert isinstance(model2.hparams, hparam_type)
+
+    if try_overwrite:
+        # verify that we can overwrite the property
+        model3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
+        assert model3.hparams.test_arg == 78
 
     return raw_checkpoint_path
@@ -82,14 +85,16 @@ def test_omega_conf_hparams(tmpdir, cls):
     # init model
     conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)]))
     model = cls(hparams=conf)
+    assert isinstance(model.hparams, Container)
 
     # run standard test suite
     raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, cls)
-    model = cls.load_from_checkpoint(raw_checkpoint_path)
+    model2 = cls.load_from_checkpoint(raw_checkpoint_path)
+    assert isinstance(model2.hparams, Container)
 
     # config specific tests
-    assert model.hparams.test_arg == 14
-    assert model.hparams.mylist[0] == 15.4
+    assert model2.hparams.test_arg == 14
+    assert model2.hparams.mylist[0] == 15.4
 
 
 def test_explicit_args_hparams(tmpdir):
@@ -157,8 +162,8 @@ def test_explicit_missing_args_hparams(tmpdir):
     # make sure the raw checkpoint saved the properties
     raw_checkpoint_path = _raw_checkpoint_path(trainer)
     raw_checkpoint = torch.load(raw_checkpoint_path)
-    assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in raw_checkpoint
-    assert raw_checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]['test_arg'] == 14
+    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
+    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14
 
     # verify that model loads correctly
     model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
@@ -199,110 +204,91 @@ def test_class_nesting():
     A().test()
 
 
-@pytest.mark.xfail(sys.version_info >= (3, 6), reason='OmegaConf only for Python >= 3.8')
-def test_omegaconf(tmpdir):
-    class OmegaConfModel(EvalModelTemplate):
-        def __init__(self, ogc):
-            super().__init__()
-            self.ogc = ogc
-            self.size = ogc.list[0]
+class SubClassEvalModel(EvalModelTemplate):
+    any_other_loss = torch.nn.CrossEntropyLoss()
 
-    conf = OmegaConf.create({"k": "v", "list": [15.4, {"a": "1", "b": "2"}]})
-    model = OmegaConfModel(conf)
+    def __init__(self, *args, subclass_arg=1200, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.save_hyperparameters()
 
-    # ensure ogc passed values correctly
-    assert model.size == 15.4
 
+class SubSubClassEvalModel(SubClassEvalModel):
+    pass
+
+
+class AggSubClassEvalModel(SubClassEvalModel):
+
+    def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
+        super().__init__(*args, **kwargs)
+        self.save_hyperparameters()
+
+
+class UnconventionalArgsEvalModel(EvalModelTemplate):
+    """ A model that has unconventional names for "self", "*args" and "**kwargs". """
+
+    def __init__(obj, *more_args, other_arg=300, **more_kwargs):
+        # intentionally named obj
+        super().__init__(*more_args, **more_kwargs)
+        obj.save_hyperparameters()
+
+
+class DictConfSubClassEvalModel(SubClassEvalModel):
+    def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs):
+        super().__init__(*args, **kwargs)
+        self.save_hyperparameters()
+
+
+@pytest.mark.parametrize("cls", [
+    EvalModelTemplate,
+    SubClassEvalModel,
+    SubSubClassEvalModel,
+    AggSubClassEvalModel,
+    UnconventionalArgsEvalModel,
+    DictConfSubClassEvalModel,
+])
+def test_collect_init_arguments(tmpdir, cls):
+    """ Test that the model automatically saves the arguments passed into the constructor """
+    extra_args = {}
+    if cls is AggSubClassEvalModel:
+        extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
+    elif cls is DictConfSubClassEvalModel:
+        extra_args.update(dict_conf=OmegaConf.create(dict(my_param='anything')))
+
+    model = cls(**extra_args)
+    assert model.hparams.batch_size == 32
+    model = cls(batch_size=179, **extra_args)
+    assert model.hparams.batch_size == 179
+
+    if isinstance(model, SubClassEvalModel):
+        assert model.hparams.subclass_arg == 1200
+
+    if isinstance(model, AggSubClassEvalModel):
+        assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)
+
+    # verify that the checkpoint saved the correct values
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
-    result = trainer.fit(model)
+    trainer.fit(model)
+    raw_checkpoint_path = _raw_checkpoint_path(trainer)
 
-    assert result == 1
+    raw_checkpoint = torch.load(raw_checkpoint_path)
+    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
+    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['batch_size'] == 179
 
+    # verify that model loads correctly
+    # TODO: uncomment and make it pass
+    # model = cls.load_from_checkpoint(raw_checkpoint_path)
+    # assert model.hparams.batch_size == 179
+    #
+    # if isinstance(model, AggSubClassEvalModel):
+    #     assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
+    #
+    # if isinstance(model, DictConfSubClassEvalModel):
+    #     assert isinstance(model.hparams.dict_conf, Container)
+    #     assert model.hparams.dict_conf == 'anything'
 
-# class SubClassEvalModel(EvalModelTemplate):
-#     any_other_loss = torch.nn.CrossEntropyLoss()
-#
-#     def __init__(self, *args, subclass_arg=1200, **kwargs):
-#         super().__init__(*args, **kwargs)
-#         self.save_hyperparameters()
-#
-#
-# class SubSubClassEvalModel(SubClassEvalModel):
-#     pass
-#
-#
-# class AggSubClassEvalModel(SubClassEvalModel):
-#
-#     def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
-#         super().__init__(*args, **kwargs)
-#         self.save_hyperparameters()
-#
-#
-# class UnconventionalArgsEvalModel(EvalModelTemplate):
-#     """ A model that has unconventional names for "self", "*args" and "**kwargs". """
-#
-#     def __init__(obj, *more_args, other_arg=300, **more_kwargs):
-#         # intentionally named obj
-#         super().__init__(*more_args, **more_kwargs)
-#         obj.save_hyperparameters()
-#
-#
-# class DictConfSubClassEvalModel(SubClassEvalModel):
-#     def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs):
-#         super().__init__(*args, **kwargs)
-#         self.save_hyperparameters()
-#
-#
-# @pytest.mark.parametrize("cls", [
-#     EvalModelTemplate,
-#     SubClassEvalModel,
-#     SubSubClassEvalModel,
-#     AggSubClassEvalModel,
-#     UnconventionalArgsEvalModel,
-#     DictConfSubClassEvalModel,
-# ])
-# def test_collect_init_arguments(tmpdir, cls):
-#     """ Test that the model automatically saves the arguments passed into the constructor """
-#     extra_args = {}
-#     if cls is AggSubClassEvalModel:
-#         extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
-#     elif cls is DictConfSubClassEvalModel:
-#         extra_args.update(dict_conf=OmegaConf.create(dict(my_param='anything')))
-#
-#     model = cls(**extra_args)
-#     assert model.batch_size == 32
-#     model = cls(batch_size=179, **extra_args)
-#     assert model.batch_size == 179
-#
-#     if isinstance(model, SubClassEvalModel):
-#         assert model.subclass_arg == 1200
-#
-#     if isinstance(model, AggSubClassEvalModel):
-#         assert isinstance(model.my_loss, torch.nn.CosineEmbeddingLoss)
-#
-#     # verify that the checkpoint saved the correct values
-#     trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
-#     trainer.fit(model)
-#     raw_checkpoint_path = _raw_checkpoint_path(trainer)
-#
-#     raw_checkpoint = torch.load(raw_checkpoint_path)
-#     assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in raw_checkpoint
-#     assert raw_checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]['batch_size'] == 179
-#
-#     # verify that model loads correctly
-#     model = cls.load_from_checkpoint(raw_checkpoint_path)
-#     assert model.batch_size == 179
-#
-#     if isinstance(model, AggSubClassEvalModel):
-#         assert isinstance(model.my_loss, torch.nn.CrossEntropyLoss)
-#
-#     if isinstance(model, DictConfSubClassEvalModel):
-#         assert isinstance(model.dict_conf, DictConfig)
-#         assert model.dict_conf == 'anything'
-#
-#     # verify that we can overwrite whatever we want
-#     model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
-#     assert model.batch_size == 99
+    # verify that we can overwrite whatever we want
+    model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
+    assert model.hparams.batch_size == 99
 
 
 def _raw_checkpoint_path(trainer) -> str:
@@ -394,6 +380,28 @@ def test_single_config_models_fail(tmpdir, cls, config):
         _ = cls(**config)
 
 
+@pytest.mark.parametrize("past_key", ['module_arguments'])
+def test_load_past_checkpoint(tmpdir, past_key):
+    model = EvalModelTemplate()
+
+    # verify we can train
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    trainer.fit(model)
+
+    # make sure the raw checkpoint saved the properties
+    raw_checkpoint_path = _raw_checkpoint_path(trainer)
+    raw_checkpoint = torch.load(raw_checkpoint_path)
+    raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
+    raw_checkpoint[past_key]['batch_size'] = -17
+    del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
+    # save back the checkpoint
+    torch.save(raw_checkpoint, raw_checkpoint_path)
+
+    # verify that model loads correctly
+    model2 = EvalModelTemplate.load_from_checkpoint(raw_checkpoint_path)
+    assert model2.hparams.batch_size == -17
+
+
 def test_hparams_pickle(tmpdir):
     ad = AttributeDict({'key1': 1, 'key2': 'abc'})
     pkl = pickle.dumps(ad)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index f391260c13..1e99236e7b 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -48,7 +48,7 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
 
     # assert ckpt has hparams
     ckpt = torch.load(new_weights_path)
-    assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in ckpt.keys(), 'module_arguments missing from checkpoints'
+    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in ckpt.keys(), 'module_arguments missing from checkpoints'
 
     # load new model
     hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)