diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index bed26fa756..97c4f87cae 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -150,11 +150,11 @@ class ModelIO(object): # override the hparams with values that were passed in checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs) - model = cls._load_model_state(checkpoint, strict=strict, *args, **kwargs) + model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs) return model @classmethod - def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, *cls_args, **cls_kwargs): + def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, strict: bool = True, **cls_kwargs): cls_spec = inspect.getfullargspec(cls.__init__) cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() # pass in the values we saved automatically