Make sure all kwargs come after args in _load_model_state() (#3063)

This commit is contained in:
Peter Yu 2020-08-20 07:29:44 -04:00 committed by GitHub
parent 88886ace72
commit 9a605642a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -150,11 +150,11 @@ class ModelIO(object):
# override the hparams with values that were passed in
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
model = cls._load_model_state(checkpoint, strict=strict, *args, **kwargs)
model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
return model
@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, *cls_args, **cls_kwargs):
def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, strict: bool = True, **cls_kwargs):
cls_spec = inspect.getfullargspec(cls.__init__)
cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()
# pass in the values we saved automatically