parent
e82d9cdb66
commit
51711c265a
|
@ -11,7 +11,7 @@ Fixes # (issue)
|
|||
|
||||
- [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
|
||||
- [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section?
|
||||
- [ ] Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you create a separate PR for every change.
|
||||
- [ ] Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
|
||||
- [ ] Did you make sure to update the documentation with your changes?
|
||||
- [ ] Did you write any new necessary tests?
|
||||
- [ ] Did you verify new and existing tests pass locally with your changes?
|
||||
|
|
|
@ -40,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335))
|
||||
|
||||
- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))
|
||||
|
||||
## [0.8.1] - 2020-06-19
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -170,7 +170,7 @@ class ModelIO(object):
|
|||
return model
|
||||
|
||||
@classmethod
|
||||
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
|
||||
def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
|
||||
# pass in the values we saved automatically
|
||||
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
|
||||
model_args = {}
|
||||
|
@ -184,19 +184,23 @@ class ModelIO(object):
|
|||
model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
|
||||
|
||||
args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
|
||||
init_args_name = inspect.signature(cls).parameters.keys()
|
||||
cls_spec = inspect.getfullargspec(cls.__init__)
|
||||
kwargs_identifier = cls_spec.varkw
|
||||
cls_init_args_name = inspect.signature(cls).parameters.keys()
|
||||
|
||||
if args_name == 'kwargs':
|
||||
cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
|
||||
kwargs.update(**cls_kwargs)
|
||||
# in case the class cannot take any extra argument filter only the possible
|
||||
if not kwargs_identifier:
|
||||
model_args = {k: v for k, v in model_args.items() if k in cls_init_args_name}
|
||||
cls_kwargs.update(**model_args)
|
||||
elif args_name:
|
||||
if args_name in init_args_name:
|
||||
kwargs.update({args_name: model_args})
|
||||
if args_name in cls_init_args_name:
|
||||
cls_kwargs.update({args_name: model_args})
|
||||
else:
|
||||
args = (model_args, ) + args
|
||||
cls_args = (model_args,) + cls_args
|
||||
|
||||
# load the state_dict on the model automatically
|
||||
model = cls(*args, **kwargs)
|
||||
model = cls(*cls_args, **cls_kwargs)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
# give model a chance to load something
|
||||
|
|
|
@ -324,7 +324,7 @@ class TrainerIOMixin(ABC):
|
|||
checkpoint = {
|
||||
'epoch': self.current_epoch + 1,
|
||||
'global_step': self.global_step + 1,
|
||||
'pytorch-ligthning_version': pytorch_lightning.__version__,
|
||||
'pytorch-lightning_version': pytorch_lightning.__version__,
|
||||
}
|
||||
|
||||
if not weights_only:
|
||||
|
|
|
@ -36,19 +36,19 @@ class EvalModelTemplate(
|
|||
>>> model = EvalModelTemplate()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*args,
|
||||
drop_prob: float = 0.2,
|
||||
batch_size: int = 32,
|
||||
in_features: int = 28 * 28,
|
||||
learning_rate: float = 0.001 * 8,
|
||||
optimizer_name: str = 'adam',
|
||||
data_root: str = PATH_DATASETS,
|
||||
out_features: int = 10,
|
||||
hidden_dim: int = 1000,
|
||||
b1: float = 0.5,
|
||||
b2: float = 0.999,
|
||||
**kwargs) -> object:
|
||||
def __init__(
|
||||
self,
|
||||
drop_prob: float = 0.2,
|
||||
batch_size: int = 32,
|
||||
in_features: int = 28 * 28,
|
||||
learning_rate: float = 0.001 * 8,
|
||||
optimizer_name: str = 'adam',
|
||||
data_root: str = PATH_DATASETS,
|
||||
out_features: int = 10,
|
||||
hidden_dim: int = 1000,
|
||||
b1: float = 0.5,
|
||||
b2: float = 0.999
|
||||
):
|
||||
# init superclass
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
|
|
@ -275,16 +275,15 @@ def test_collect_init_arguments(tmpdir, cls):
|
|||
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['batch_size'] == 179
|
||||
|
||||
# verify that model loads correctly
|
||||
# TODO: uncomment and get it pass
|
||||
# model = cls.load_from_checkpoint(raw_checkpoint_path)
|
||||
# assert model.hparams.batch_size == 179
|
||||
#
|
||||
# if isinstance(model, AggSubClassEvalModel):
|
||||
# assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
|
||||
#
|
||||
# if isinstance(model, DictConfSubClassEvalModel):
|
||||
# assert isinstance(model.hparams.dict_conf, Container)
|
||||
# assert model.hparams.dict_conf == 'anything'
|
||||
model = cls.load_from_checkpoint(raw_checkpoint_path)
|
||||
assert model.hparams.batch_size == 179
|
||||
|
||||
if isinstance(model, AggSubClassEvalModel):
|
||||
assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
|
||||
|
||||
if isinstance(model, DictConfSubClassEvalModel):
|
||||
assert isinstance(model.hparams.dict_conf, Container)
|
||||
assert model.hparams.dict_conf['my_param'] == 'anything'
|
||||
|
||||
# verify that we can overwrite whatever we want
|
||||
model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
|
||||
|
|
|
@ -128,9 +128,8 @@ class ModelVer0_7(EvalModelTemplate):
|
|||
|
||||
|
||||
def test_tbd_remove_in_v1_0_0_model_hooks():
|
||||
hparams = EvalModelTemplate.get_default_hparams()
|
||||
|
||||
model = ModelVer0_6(hparams)
|
||||
model = ModelVer0_6()
|
||||
|
||||
with pytest.deprecated_call(match='v1.0'):
|
||||
trainer = Trainer(logger=False)
|
||||
|
@ -143,7 +142,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
|
|||
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
|
||||
assert result == {'val_loss': torch.tensor(0.6)}
|
||||
|
||||
model = ModelVer0_7(hparams)
|
||||
model = ModelVer0_7()
|
||||
|
||||
with pytest.deprecated_call(match='v1.0'):
|
||||
trainer = Trainer(logger=False)
|
||||
|
|
|
@ -182,7 +182,7 @@ def test_suggestion_with_non_finite_values(tmpdir):
|
|||
""" Test that non-finite values does not alter results """
|
||||
|
||||
hparams = EvalModelTemplate.get_default_hparams()
|
||||
model = EvalModelTemplate(hparams)
|
||||
model = EvalModelTemplate(**hparams)
|
||||
|
||||
# logger file to get meta
|
||||
trainer = Trainer(
|
||||
|
|
|
@ -231,7 +231,7 @@ def test_configure_optimizer_from_dict(tmpdir):
|
|||
return config
|
||||
|
||||
hparams = EvalModelTemplate.get_default_hparams()
|
||||
model = CurrentModel(hparams)
|
||||
model = CurrentModel(**hparams)
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(
|
||||
|
|
Loading…
Reference in New Issue