remove frame inspection on self.hparams (#2253)

* remove frame inspection on self.hparams

William Falcon 2020-06-18 23:08:25 -04:00 committed by GitHub
parent 4885cfad03
commit 6ae9a97b09
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 2 deletions

View File

@@ -1,6 +1,7 @@
 import collections
 import inspect
 import os
+import re
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
@@ -1692,4 +1693,20 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
     @hparams.setter
     def hparams(self, hp: Union[dict, Namespace, Any]):
-        self.save_hyperparameters(hp, frame=inspect.currentframe().f_back.f_back)
+        hparams_assignment_name = self.__get_hparams_assignment_variable()
+        self._hparams_name = hparams_assignment_name
+        self._set_hparams(hp)
+
+    def __get_hparams_assignment_variable(self):
+        """
+        looks at the code of the class to figure out what the user named self.hparams
+        this only happens when the user explicitly sets self.hparams
+        """
+        class_code = inspect.getsource(self.__class__)
+        lines = class_code.split('\n')
+        for line in lines:
+            line = re.sub(r"\s+", "", line, flags=re.UNICODE)
+            if 'self.hparams=' in line:
+                return line.split('=')[1]
+
+        return None
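The setter no longer walks call frames; instead, `__get_hparams_assignment_variable` scans the subclass source for the literal `self.hparams = ...` assignment and records the right-hand-side name in `self._hparams_name`. Below is a minimal, self-contained sketch of the same technique outside Lightning; `DemoModule`, `_find_assignment_name`, and the variable name `conf` are made-up illustration names, not part of the library.

```python
import inspect
import re


class DemoModule:
    def __init__(self, conf):
        # this assignment goes through the property setter below
        self.hparams = conf

    @property
    def hparams(self):
        return getattr(self, "_hparams", None)

    @hparams.setter
    def hparams(self, hp):
        # record which variable the user assigned, then store the values
        self._hparams_name = self._find_assignment_name()
        self._hparams = hp

    def _find_assignment_name(self):
        # strip all whitespace from every source line and return the
        # right-hand side of the first explicit assignment to self.hparams
        for line in inspect.getsource(self.__class__).split("\n"):
            line = re.sub(r"\s+", "", line, flags=re.UNICODE)
            if "self.hparams=" in line:
                return line.split("=")[1]
        return None


# run as a script: inspect.getsource needs the class to live in a file
model = DemoModule({"learning_rate": 0.02})
print(model._hparams_name)  # -> conf
```

Note the caveat this approach carries: the first whitespace-stripped line containing the assignment wins, and the text right after the first `=` is taken verbatim as the name, so an inline comment on that line or a second assignment elsewhere in the class can change what gets recorded.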

View File

@@ -176,14 +176,18 @@ class ModelIO(object):
        # pass in the values we saved automatically
        if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
            model_args = {}

            # add some back compatibility, the actual one shall be last
            for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
                if hparam_key in checkpoint:
                    model_args.update(checkpoint[hparam_key])

            if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
                model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)

            args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
            init_args_name = inspect.signature(cls).parameters.keys()

            if args_name == 'kwargs':
                cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
                kwargs.update(**cls_kwargs)
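On the loading side, the key step is the final block: when the checkpoint records that the hyperparameters were passed in as `**kwargs` (`args_name == 'kwargs'`), the merged hyperparameter dict is filtered against `inspect.signature(cls)` so that only arguments the constructor actually accepts are forwarded. A small standalone sketch of that filtering idea, using made-up names (`TinyModel`, `filter_to_init_args`) rather than the Lightning API:

```python
import inspect


class TinyModel:
    def __init__(self, hidden_dim=16, learning_rate=1e-3):
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate


def filter_to_init_args(cls, saved_hparams):
    # inspect.signature(cls) resolves to the constructor signature (no `self`),
    # so parameters.keys() lists exactly the argument names __init__ accepts
    init_args = inspect.signature(cls).parameters.keys()
    return {k: v for k, v in saved_hparams.items() if k in init_args}


checkpoint_hparams = {"hidden_dim": 32, "learning_rate": 0.01, "git_sha": "abc123"}
cls_kwargs = filter_to_init_args(TinyModel, checkpoint_hparams)
model = TinyModel(**cls_kwargs)  # "git_sha" is dropped: __init__ does not take it
print(cls_kwargs)                # {'hidden_dim': 32, 'learning_rate': 0.01}
```

Filtering this way keeps stale or extra keys saved in older checkpoints from breaking `cls(**kwargs)` at load time.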

View File

@@ -39,7 +39,7 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
     assert model.hparams.test_arg == 14

     # verify we can train
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=0.5)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
     trainer.fit(model)

     # make sure the raw checkpoint saved the properties
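In the test, `overfit_batches=0.5` (a float, i.e. half of the training data) is replaced with `overfit_batches=2` (an int, i.e. exactly two batches), presumably so the one-epoch smoke run stays small regardless of dataset size. Both spellings are valid `Trainer` arguments; a quick sketch of the two forms:

```python
from pytorch_lightning import Trainer

# float: overfit on a fraction of the training set
trainer_fraction = Trainer(max_epochs=1, overfit_batches=0.5)

# int: overfit on an exact number of training batches
trainer_count = Trainer(max_epochs=1, overfit_batches=2)
```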