remove frame inspection on self.hparams (#2253)
* remove frame inspection on self.hparams * remove frame inspection on self.hparams * remove frame inspection on self.hparams * remove frame inspection on self.hparams * remove frame inspection on self.hparams * remove frame inspection on self.hparams
This commit is contained in:
parent
4885cfad03
commit
6ae9a97b09
|
@ -1,6 +1,7 @@
|
|||
import collections
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from argparse import Namespace
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
|
@ -1692,4 +1693,20 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
|
|||
|
||||
@hparams.setter
|
||||
def hparams(self, hp: Union[dict, Namespace, Any]):
|
||||
self.save_hyperparameters(hp, frame=inspect.currentframe().f_back.f_back)
|
||||
hparams_assignment_name = self.__get_hparams_assignment_variable()
|
||||
self._hparams_name = hparams_assignment_name
|
||||
self._set_hparams(hp)
|
||||
|
||||
def __get_hparams_assignment_variable(self):
|
||||
"""
|
||||
looks at the code of the class to figure out what the user named self.hparams
|
||||
this only happens when the user explicitly sets self.hparams
|
||||
"""
|
||||
class_code = inspect.getsource(self.__class__)
|
||||
lines = class_code.split('\n')
|
||||
for line in lines:
|
||||
line = re.sub(r"\s+", "", line, flags=re.UNICODE)
|
||||
if 'self.hparams=' in line:
|
||||
return line.split('=')[1]
|
||||
|
||||
return None
|
||||
|
|
|
@ -176,14 +176,18 @@ class ModelIO(object):
|
|||
# pass in the values we saved automatically
|
||||
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
|
||||
model_args = {}
|
||||
|
||||
# add some back compatibility, the actual one shall be last
|
||||
for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
|
||||
if hparam_key in checkpoint:
|
||||
model_args.update(checkpoint[hparam_key])
|
||||
|
||||
if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
|
||||
model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
|
||||
|
||||
args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
|
||||
init_args_name = inspect.signature(cls).parameters.keys()
|
||||
|
||||
if args_name == 'kwargs':
|
||||
cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
|
||||
kwargs.update(**cls_kwargs)
|
||||
|
|
|
@ -39,7 +39,7 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
|
|||
assert model.hparams.test_arg == 14
|
||||
|
||||
# verify we can train
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=0.5)
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
|
||||
trainer.fit(model)
|
||||
|
||||
# make sure the raw checkpoint saved the properties
|
||||
|
|
Loading…
Reference in New Issue