Fix missing arguments when saving hyperparams from parent class only (#9800)
* Fix missing arguments when saving hyperparams from parent class only * fix antipattern
This commit is contained in:
parent
7c6efbc8a8
commit
86ad941d06
|
@ -444,6 +444,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed DeepSpeed and Lightning both calling the scheduler ([#9788](https://github.com/PyTorchLightning/pytorch-lightning/pull/9788))
|
||||
|
||||
- Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))
|
||||
|
||||
|
||||
## [1.4.9] - 2021-09-30
|
||||
|
||||
|
|
|
@ -217,7 +217,9 @@ def save_hyperparameters(
|
|||
if is_dataclass(obj):
|
||||
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
|
||||
else:
|
||||
init_args = get_init_args(frame)
|
||||
init_args = {}
|
||||
for local_args in collect_init_args(frame, []):
|
||||
init_args.update(local_args)
|
||||
assert init_args, "failed to inspect the obj init"
|
||||
|
||||
if ignore is not None:
|
||||
|
|
|
@ -247,6 +247,13 @@ class SubClassBoringModel(CustomBoringModel):
|
|||
self.save_hyperparameters()
|
||||
|
||||
|
||||
class NonSavingSubClassBoringModel(CustomBoringModel):
|
||||
any_other_loss = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def __init__(self, *args, subclass_arg=1200, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SubSubClassBoringModel(SubClassBoringModel):
|
||||
pass
|
||||
|
||||
|
@ -277,6 +284,7 @@ class DictConfSubClassBoringModel(SubClassBoringModel):
|
|||
[
|
||||
CustomBoringModel,
|
||||
SubClassBoringModel,
|
||||
NonSavingSubClassBoringModel,
|
||||
SubSubClassBoringModel,
|
||||
AggSubClassBoringModel,
|
||||
UnconventionalArgsBoringModel,
|
||||
|
@ -296,7 +304,7 @@ def test_collect_init_arguments(tmpdir, cls):
|
|||
model = cls(batch_size=179, **extra_args)
|
||||
assert model.hparams.batch_size == 179
|
||||
|
||||
if isinstance(model, SubClassBoringModel):
|
||||
if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel)):
|
||||
assert model.hparams.subclass_arg == 1200
|
||||
|
||||
if isinstance(model, AggSubClassBoringModel):
|
||||
|
|
Loading…
Reference in New Issue