From 51711c265a9e234f2b4164f1a2fab73373707d61 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 27 Jun 2020 22:38:03 +0200 Subject: [PATCH] fix loading model with kwargs (#2387) * test * fix * fix --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- CHANGELOG.md | 2 ++ pytorch_lightning/core/saving.py | 20 ++++++++++-------- pytorch_lightning/trainer/training_io.py | 2 +- tests/base/model_template.py | 26 ++++++++++++------------ tests/models/test_hparams.py | 19 ++++++++--------- tests/test_deprecated.py | 5 ++--- tests/trainer/test_lr_finder.py | 2 +- tests/trainer/test_optimizers.py | 2 +- 9 files changed, 42 insertions(+), 38 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 9a84d70f8f..c3e42731e3 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -11,7 +11,7 @@ Fixes # (issue) - [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements) - [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section? -- [ ] Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you create a separate PR for every change. +- [ ] Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change. - [ ] Did you make sure to update the documentation with your changes? - [ ] Did you write any new necessary tests? - [ ] Did you verify new and existing tests pass locally with your changes? diff --git a/CHANGELOG.md b/CHANGELOG.md index 617bbcb749..003aec2ea7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335)) +- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387)) + ## [0.8.1] - 2020-06-19 ### Fixed diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index d3085a8bfe..4ad11e6ec1 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -170,7 +170,7 @@ class ModelIO(object): return model @classmethod - def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs): + def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs): # pass in the values we saved automatically if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: model_args = {} @@ -184,19 +184,23 @@ class ModelIO(object): model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args) args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME) - init_args_name = inspect.signature(cls).parameters.keys() + cls_spec = inspect.getfullargspec(cls.__init__) + kwargs_identifier = cls_spec.varkw + cls_init_args_name = inspect.signature(cls).parameters.keys() if args_name == 'kwargs': - cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name} - kwargs.update(**cls_kwargs) + # in case the class cannot take any extra argument filter only the possible + if not kwargs_identifier: + model_args = {k: v for k, v in model_args.items() if k in cls_init_args_name} + cls_kwargs.update(**model_args) elif args_name: - if args_name in init_args_name: - kwargs.update({args_name: model_args}) + if args_name in cls_init_args_name: + cls_kwargs.update({args_name: model_args}) else: - args = (model_args, ) + args + cls_args = (model_args,) + cls_args # load the state_dict on the model automatically - model = cls(*args, **kwargs) + model = cls(*cls_args, **cls_kwargs) model.load_state_dict(checkpoint['state_dict']) # give model a chance to load something diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index b84a26c9e2..4c672743dc 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -324,7 +324,7 @@ class TrainerIOMixin(ABC): checkpoint = { 'epoch': self.current_epoch + 1, 'global_step': self.global_step + 1, - 'pytorch-ligthning_version': pytorch_lightning.__version__, + 'pytorch-lightning_version': pytorch_lightning.__version__, } if not weights_only: diff --git a/tests/base/model_template.py b/tests/base/model_template.py index 44dba72270..48851cdb08 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -36,19 +36,19 @@ class EvalModelTemplate( >>> model = EvalModelTemplate() """ - def __init__(self, - *args, - drop_prob: float = 0.2, - batch_size: int = 32, - in_features: int = 28 * 28, - learning_rate: float = 0.001 * 8, - optimizer_name: str = 'adam', - data_root: str = PATH_DATASETS, - out_features: int = 10, - hidden_dim: int = 1000, - b1: float = 0.5, - b2: float = 0.999, - **kwargs) -> object: + def __init__( + self, + drop_prob: float = 0.2, + batch_size: int = 32, + in_features: int = 28 * 28, + learning_rate: float = 0.001 * 8, + optimizer_name: str = 'adam', + data_root: str = PATH_DATASETS, + out_features: int = 10, + hidden_dim: int = 1000, + b1: float = 0.5, + b2: float = 0.999 + ): # init superclass super().__init__() self.save_hyperparameters() diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 7c85e734f9..e2b200adc5 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -275,16 +275,15 @@ def test_collect_init_arguments(tmpdir, cls): assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['batch_size'] == 179 # verify that model loads correctly - # TODO: uncomment and get it pass - # model = cls.load_from_checkpoint(raw_checkpoint_path) - # assert model.hparams.batch_size == 179 - # - # if isinstance(model, AggSubClassEvalModel): - # assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss) - # - # if isinstance(model, DictConfSubClassEvalModel): - # assert isinstance(model.hparams.dict_conf, Container) - # assert model.hparams.dict_conf == 'anything' + model = cls.load_from_checkpoint(raw_checkpoint_path) + assert model.hparams.batch_size == 179 + + if isinstance(model, AggSubClassEvalModel): + assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss) + + if isinstance(model, DictConfSubClassEvalModel): + assert isinstance(model.hparams.dict_conf, Container) + assert model.hparams.dict_conf['my_param'] == 'anything' # verify that we can overwrite whatever we want model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 0ec60a4a2e..e119c2bff8 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -128,9 +128,8 @@ class ModelVer0_7(EvalModelTemplate): def test_tbd_remove_in_v1_0_0_model_hooks(): - hparams = EvalModelTemplate.get_default_hparams() - model = ModelVer0_6(hparams) + model = ModelVer0_6() with pytest.deprecated_call(match='v1.0'): trainer = Trainer(logger=False) @@ -143,7 +142,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks(): result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1) assert result == {'val_loss': torch.tensor(0.6)} - model = ModelVer0_7(hparams) + model = ModelVer0_7() with pytest.deprecated_call(match='v1.0'): trainer = Trainer(logger=False) diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index 69cbc47db5..08f97cf6b1 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -182,7 +182,7 @@ def test_suggestion_with_non_finite_values(tmpdir): """ Test that non-finite values does not alter results """ hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) # logger file to get meta trainer = Trainer( diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index 5d49543705..3a85a756fc 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -231,7 +231,7 @@ def test_configure_optimizer_from_dict(tmpdir): return config hparams = EvalModelTemplate.get_default_hparams() - model = CurrentModel(hparams) + model = CurrentModel(**hparams) # fit model trainer = Trainer(