Support `save_hyperparameters()` in LightningModule dataclass (#7992)

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
Adrian Wälchli 2021-06-16 10:30:58 +02:00 committed by GitHub
parent 341adad819
commit b093a9e66d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 1 deletions

View File

@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.nn.UninitializedParameter` in `ModelSummary` ([#7642](https://github.com/PyTorchLightning/pytorch-lightning/pull/7642))
- Added support `LightningModule.save_hyperparameters` when `LightningModule` is a dataclass ([#7992](https://github.com/PyTorchLightning/pytorch-lightning/pull/7992))
### Changed

View File

@ -16,6 +16,7 @@ import inspect
import pickle
import types
from argparse import Namespace
from dataclasses import fields, is_dataclass
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from pytorch_lightning.utilities import rank_zero_warn
@ -197,7 +198,11 @@ def save_hyperparameters(
if not frame:
frame = inspect.currentframe().f_back
init_args = get_init_args(frame)
if is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = get_init_args(frame)
assert init_args, "failed to inspect the obj init"
if ignore is not None:

View File

@ -15,6 +15,7 @@ import functools
import os
import pickle
from argparse import Namespace
from dataclasses import dataclass
import cloudpickle
import pytest
@ -719,3 +720,21 @@ def test_empty_hparams_container(tmpdir):
assert not model.hparams
model = HparamsNamespaceContainerModel(Namespace())
assert not model.hparams
@dataclass
class DataClassModel(BoringModel):
mandatory: int
optional: str = "optional"
ignore_me: bool = False
def __post_init__(self):
super().__init__()
self.save_hyperparameters(ignore=("ignore_me", ))
def test_dataclass_lightning_module(tmpdir):
""" Test that save_hyperparameters() works with a LightningModule as a dataclass. """
model = DataClassModel(33, optional="cocofruit")
assert model.hparams == dict(mandatory=33, optional="cocofruit")