fix loading model with kwargs (#2387)

* test

* fix

* fix
This commit is contained in:
Jirka Borovec 2020-06-27 22:38:03 +02:00 committed by GitHub
parent e82d9cdb66
commit 51711c265a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 42 additions and 38 deletions

View File

@ -11,7 +11,7 @@ Fixes # (issue)
- [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
- [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section?
- [ ] Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you create a separate PR for every change.
- [ ] Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
- [ ] Did you verify new and existing tests pass locally with your changes?

View File

@ -40,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335))
- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))
## [0.8.1] - 2020-06-19
### Fixed

View File

@ -170,7 +170,7 @@ class ModelIO(object):
return model
@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
# pass in the values we saved automatically
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
model_args = {}
@ -184,19 +184,23 @@ class ModelIO(object):
model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
init_args_name = inspect.signature(cls).parameters.keys()
cls_spec = inspect.getfullargspec(cls.__init__)
kwargs_identifier = cls_spec.varkw
cls_init_args_name = inspect.signature(cls).parameters.keys()
if args_name == 'kwargs':
cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
kwargs.update(**cls_kwargs)
# in case the class cannot take any extra argument filter only the possible
if not kwargs_identifier:
model_args = {k: v for k, v in model_args.items() if k in cls_init_args_name}
cls_kwargs.update(**model_args)
elif args_name:
if args_name in init_args_name:
kwargs.update({args_name: model_args})
if args_name in cls_init_args_name:
cls_kwargs.update({args_name: model_args})
else:
args = (model_args, ) + args
cls_args = (model_args,) + cls_args
# load the state_dict on the model automatically
model = cls(*args, **kwargs)
model = cls(*cls_args, **cls_kwargs)
model.load_state_dict(checkpoint['state_dict'])
# give model a chance to load something

View File

@ -324,7 +324,7 @@ class TrainerIOMixin(ABC):
checkpoint = {
'epoch': self.current_epoch + 1,
'global_step': self.global_step + 1,
'pytorch-ligthning_version': pytorch_lightning.__version__,
'pytorch-lightning_version': pytorch_lightning.__version__,
}
if not weights_only:

View File

@ -36,19 +36,19 @@ class EvalModelTemplate(
>>> model = EvalModelTemplate()
"""
def __init__(self,
*args,
drop_prob: float = 0.2,
batch_size: int = 32,
in_features: int = 28 * 28,
learning_rate: float = 0.001 * 8,
optimizer_name: str = 'adam',
data_root: str = PATH_DATASETS,
out_features: int = 10,
hidden_dim: int = 1000,
b1: float = 0.5,
b2: float = 0.999,
**kwargs) -> object:
def __init__(
self,
drop_prob: float = 0.2,
batch_size: int = 32,
in_features: int = 28 * 28,
learning_rate: float = 0.001 * 8,
optimizer_name: str = 'adam',
data_root: str = PATH_DATASETS,
out_features: int = 10,
hidden_dim: int = 1000,
b1: float = 0.5,
b2: float = 0.999
):
# init superclass
super().__init__()
self.save_hyperparameters()

View File

@ -275,16 +275,15 @@ def test_collect_init_arguments(tmpdir, cls):
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['batch_size'] == 179
# verify that model loads correctly
# TODO: uncomment and get it pass
# model = cls.load_from_checkpoint(raw_checkpoint_path)
# assert model.hparams.batch_size == 179
#
# if isinstance(model, AggSubClassEvalModel):
# assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
#
# if isinstance(model, DictConfSubClassEvalModel):
# assert isinstance(model.hparams.dict_conf, Container)
# assert model.hparams.dict_conf == 'anything'
model = cls.load_from_checkpoint(raw_checkpoint_path)
assert model.hparams.batch_size == 179
if isinstance(model, AggSubClassEvalModel):
assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
if isinstance(model, DictConfSubClassEvalModel):
assert isinstance(model.hparams.dict_conf, Container)
assert model.hparams.dict_conf['my_param'] == 'anything'
# verify that we can overwrite whatever we want
model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)

View File

@ -128,9 +128,8 @@ class ModelVer0_7(EvalModelTemplate):
def test_tbd_remove_in_v1_0_0_model_hooks():
hparams = EvalModelTemplate.get_default_hparams()
model = ModelVer0_6(hparams)
model = ModelVer0_6()
with pytest.deprecated_call(match='v1.0'):
trainer = Trainer(logger=False)
@ -143,7 +142,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
assert result == {'val_loss': torch.tensor(0.6)}
model = ModelVer0_7(hparams)
model = ModelVer0_7()
with pytest.deprecated_call(match='v1.0'):
trainer = Trainer(logger=False)

View File

@ -182,7 +182,7 @@ def test_suggestion_with_non_finite_values(tmpdir):
""" Test that non-finite values does not alter results """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)
model = EvalModelTemplate(**hparams)
# logger file to get meta
trainer = Trainer(

View File

@ -231,7 +231,7 @@ def test_configure_optimizer_from_dict(tmpdir):
return config
hparams = EvalModelTemplate.get_default_hparams()
model = CurrentModel(hparams)
model = CurrentModel(**hparams)
# fit model
trainer = Trainer(