Fix `save_hyperparameters` when no parameters need saving (#11827)

This commit is contained in:
Adrian Wälchli 2022-02-10 00:10:14 +01:00 committed by GitHub
parent a2d8c4f6a6
commit c618e59689
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 3 deletions

View File

@ -610,6 +610,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Configure native Deepspeed schedulers with interval='step' ([#11788](https://github.com/PyTorchLightning/pytorch-lightning/pull/11788))
- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))
## [1.5.9] - 2022-01-20
### Fixed

View File

@ -224,7 +224,6 @@ def save_hyperparameters(
init_args = {}
for local_args in collect_init_args(frame, []):
init_args.update(local_args)
assert init_args, "failed to inspect the obj init"
if ignore is not None:
if isinstance(ignore, str):
@ -249,8 +248,7 @@ def save_hyperparameters(
obj._hparams_name = "kwargs"
# `hparams` are expected here
if hp:
obj._set_hparams(hp)
obj._set_hparams(hp)
# make deep copy so there is not other runtime changes reflected
obj._hparams_initial = copy.deepcopy(obj._hparams)

View File

@ -672,6 +672,31 @@ def test_ignore_args_list_hparams(tmpdir, ignore):
assert arg not in model.hparams
class IgnoreAllParametersModel(BoringModel):
def __init__(self, arg1, arg2, arg3):
super().__init__()
self.save_hyperparameters(ignore=("arg1", "arg2", "arg3"))
class NoParametersModel(BoringModel):
def __init__(self):
super().__init__()
self.save_hyperparameters()
@pytest.mark.parametrize(
"model",
(
IgnoreAllParametersModel(arg1=14, arg2=90, arg3=50),
NoParametersModel(),
),
)
def test_save_no_parameters(model):
"""Test that calling save_hyperparameters works if no parameters need saving."""
assert model.hparams == {}
assert model._hparams_initial == {}
class HparamsKwargsContainerModel(BoringModel):
def __init__(self, **kwargs):
super().__init__()