Support `save_hyperparameters()` in LightningModule dataclass (#7992)
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
341adad819
commit
b093a9e66d
|
@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for `torch.nn.UninitializedParameter` in `ModelSummary` ([#7642](https://github.com/PyTorchLightning/pytorch-lightning/pull/7642))
|
||||
|
||||
|
||||
- Added support `LightningModule.save_hyperparameters` when `LightningModule` is a dataclass ([#7992](https://github.com/PyTorchLightning/pytorch-lightning/pull/7992))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ import inspect
|
|||
import pickle
|
||||
import types
|
||||
from argparse import Namespace
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
@ -197,7 +198,11 @@ def save_hyperparameters(
|
|||
|
||||
if not frame:
|
||||
frame = inspect.currentframe().f_back
|
||||
init_args = get_init_args(frame)
|
||||
|
||||
if is_dataclass(obj):
|
||||
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
|
||||
else:
|
||||
init_args = get_init_args(frame)
|
||||
assert init_args, "failed to inspect the obj init"
|
||||
|
||||
if ignore is not None:
|
||||
|
|
|
@ -15,6 +15,7 @@ import functools
|
|||
import os
|
||||
import pickle
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import cloudpickle
|
||||
import pytest
|
||||
|
@ -719,3 +720,21 @@ def test_empty_hparams_container(tmpdir):
|
|||
assert not model.hparams
|
||||
model = HparamsNamespaceContainerModel(Namespace())
|
||||
assert not model.hparams
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataClassModel(BoringModel):
|
||||
|
||||
mandatory: int
|
||||
optional: str = "optional"
|
||||
ignore_me: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
self.save_hyperparameters(ignore=("ignore_me", ))
|
||||
|
||||
|
||||
def test_dataclass_lightning_module(tmpdir):
|
||||
""" Test that save_hyperparameters() works with a LightningModule as a dataclass. """
|
||||
model = DataClassModel(33, optional="cocofruit")
|
||||
assert model.hparams == dict(mandatory=33, optional="cocofruit")
|
||||
|
|
Loading…
Reference in New Issue