Fix missing arguments when saving hyperparams from parent class only (#9800)

* Fix missing arguments when saving hyperparams from parent class only

* fix antipattern
This commit is contained in:
Elad Segal 2021-10-06 10:32:29 +03:00 committed by GitHub
parent 7c6efbc8a8
commit 86ad941d06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 2 deletions

View File

@ -444,6 +444,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed DeepSpeed and Lightning both calling the scheduler ([#9788](https://github.com/PyTorchLightning/pytorch-lightning/pull/9788))
- Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))
## [1.4.9] - 2021-09-30

View File

@ -217,7 +217,9 @@ def save_hyperparameters(
if is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = get_init_args(frame)
init_args = {}
for local_args in collect_init_args(frame, []):
init_args.update(local_args)
assert init_args, "failed to inspect the obj init"
if ignore is not None:

View File

@ -247,6 +247,13 @@ class SubClassBoringModel(CustomBoringModel):
self.save_hyperparameters()
class NonSavingSubClassBoringModel(CustomBoringModel):
any_other_loss = torch.nn.CrossEntropyLoss()
def __init__(self, *args, subclass_arg=1200, **kwargs):
super().__init__(*args, **kwargs)
class SubSubClassBoringModel(SubClassBoringModel):
pass
@ -277,6 +284,7 @@ class DictConfSubClassBoringModel(SubClassBoringModel):
[
CustomBoringModel,
SubClassBoringModel,
NonSavingSubClassBoringModel,
SubSubClassBoringModel,
AggSubClassBoringModel,
UnconventionalArgsBoringModel,
@ -296,7 +304,7 @@ def test_collect_init_arguments(tmpdir, cls):
model = cls(batch_size=179, **extra_args)
assert model.hparams.batch_size == 179
if isinstance(model, SubClassBoringModel):
if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel)):
assert model.hparams.subclass_arg == 1200
if isinstance(model, AggSubClassBoringModel):