Add ignore param to save_hyperparameters (#6056)

* add ignore param to save_hyperparameters

* add docstring for ignore

* add type for frame object

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* fix whitespace

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Parametrize tests

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* seq

* fix docs

* Update lightning.py

* Update lightning.py

* fix docs errors

* add example keyword

* update docstring

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Kaushik B 2021-03-05 00:32:42 +05:30 committed by GitHub
parent 48a10f16ef
commit 59acf574e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 105 additions and 39 deletions

View File

@ -19,12 +19,13 @@ import inspect
import logging
import os
import tempfile
import types
import uuid
from abc import ABC
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import torch
from torch import ScriptModule, Tensor
@ -1591,55 +1592,84 @@ class LightningModule(
parents_arguments.update(args)
return self_arguments, parents_arguments
def save_hyperparameters(self, *args, frame=None) -> None:
"""Save all model arguments.
def save_hyperparameters(
self,
*args,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None
) -> None:
"""Save model arguments to ``hparams`` attribute.
Args:
args: single object of `dict`, `NameSpace` or `OmegaConf`
or string names or arguments from class `__init__`
or string names or arguments from class ``__init__``
ignore: an argument name or a list of argument names from
class ``__init__`` to be ignored
frame: a frame object. Default is None
>>> class ManuallyArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # manually assign arguments
... self.save_hyperparameters('arg1', 'arg3')
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
Example::
>>> class ManuallyArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # manually assign arguments
... self.save_hyperparameters('arg1', 'arg3')
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
>>> class AutomaticArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # equivalent automatic
... self.save_hyperparameters()
... def forward(self, *args, **kwargs):
... ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14
>>> class AutomaticArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # equivalent automatic
... self.save_hyperparameters()
... def forward(self, *args, **kwargs):
... ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14
>>> class SingleArgModel(LightningModule):
... def __init__(self, params):
... super().__init__()
... # manually assign single argument
... self.save_hyperparameters(params)
... def forward(self, *args, **kwargs):
... ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14
>>> class SingleArgModel(LightningModule):
... def __init__(self, params):
... super().__init__()
... # manually assign single argument
... self.save_hyperparameters(params)
... def forward(self, *args, **kwargs):
... ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14
>>> class ManuallyArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # pass argument(s) to ignore as a string or in a list
... self.save_hyperparameters(ignore='arg2')
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
"""
if not frame:
frame = inspect.currentframe().f_back
init_args = get_init_args(frame)
assert init_args, "failed to inspect the self init"
if ignore is not None:
if isinstance(ignore, str):
ignore = [ignore]
if isinstance(ignore, (list, tuple)):
ignore = [arg for arg in ignore if isinstance(arg, str)]
init_args = {k: v for k, v in init_args.items() if k not in ignore}
if not args:
# take all arguments
hp = init_args

View File

@ -661,3 +661,39 @@ def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir):
)
trainer.fit(model)
_ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)
@pytest.mark.parametrize("ignore", ("arg2", ("arg2", "arg3")))
def test_ignore_args_list_hparams(tmpdir, ignore):
"""
Tests that args can be ignored in save_hyperparameters
"""
class LocalModel(BoringModel):
def __init__(self, arg1, arg2, arg3):
super().__init__()
self.save_hyperparameters(ignore=ignore)
model = LocalModel(arg1=14, arg2=90, arg3=50)
# test proper property assignments
assert model.hparams.arg1 == 14
for arg in ignore:
assert arg not in model.hparams
# verify we can train
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
trainer.fit(model)
# make sure the raw checkpoint saved the properties
raw_checkpoint_path = _raw_checkpoint_path(trainer)
raw_checkpoint = torch.load(raw_checkpoint_path)
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["arg1"] == 14
# verify that model loads correctly
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, arg2=123, arg3=100)
assert model.hparams.arg1 == 14
for arg in ignore:
assert arg not in model.hparams