diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index df163aaa10..d2190a37f5 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1,6 +1,7 @@
 import collections
 import inspect
 import os
+import re
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
@@ -1692,4 +1693,20 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
 
     @hparams.setter
     def hparams(self, hp: Union[dict, Namespace, Any]):
-        self.save_hyperparameters(hp, frame=inspect.currentframe().f_back.f_back)
+        hparams_assignment_name = self.__get_hparams_assignment_variable()
+        self._hparams_name = hparams_assignment_name
+        self._set_hparams(hp)
+
+    def __get_hparams_assignment_variable(self):
+        """
+        looks at the code of the class to figure out what the user named self.hparams
+        this only happens when the user explicitly sets self.hparams
+        """
+        class_code = inspect.getsource(self.__class__)
+        lines = class_code.split('\n')
+        for line in lines:
+            line = re.sub(r"\s+", "", line, flags=re.UNICODE)
+            if 'self.hparams=' in line:
+                return line.split('=')[1]
+
+        return None
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 453b163a02..e6c7b0e222 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -176,14 +176,18 @@ class ModelIO(object):
 
         # pass in the values we saved automatically
         if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
             model_args = {}
+            # add some back compatibility, the actual one shall be last
             for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
                 if hparam_key in checkpoint:
                     model_args.update(checkpoint[hparam_key])
+
             if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
                 model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
+
             args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
             init_args_name = inspect.signature(cls).parameters.keys()
+
             if args_name == 'kwargs':
                 cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
                 kwargs.update(**cls_kwargs)
diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index b802f894f4..7c85e734f9 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -39,7 +39,7 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
     assert model.hparams.test_arg == 14
 
     # verify we can train
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=0.5)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
     trainer.fit(model)
 
     # make sure the raw checkpoint saved the properties
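
Note (not part of the diff): a minimal standalone sketch of the source-inspection trick that the new `__get_hparams_assignment_variable` relies on. The `ToyModel` class, the free-standing `get_hparams_assignment_variable` helper, and the attribute name `conf` are hypothetical illustration only; the real setter in the patch additionally stores the detected name in `self._hparams_name` and delegates to `self._set_hparams(hp)`. Run it as a script so `inspect.getsource` can read the class source.

    import inspect
    import re
    from argparse import Namespace


    class ToyModel:
        def __init__(self, conf):
            # the user is free to pick any local name for the hparams object
            self.hparams = conf


    def get_hparams_assignment_variable(obj):
        """Return the name assigned to ``self.hparams`` in the class source, or None."""
        class_code = inspect.getsource(obj.__class__)
        for line in class_code.split('\n'):
            # drop all whitespace so the assignment can be matched literally
            line = re.sub(r"\s+", "", line, flags=re.UNICODE)
            if 'self.hparams=' in line:
                return line.split('=')[1]
        return None


    model = ToyModel(Namespace(learning_rate=0.02))
    print(get_hparams_assignment_variable(model))  # prints: conf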