Add ignore param to save_hyperparameters (#6056)
* add ignore param to save_hyperparameters
* add docstring for ignore
* add type for frame object
* Update pytorch_lightning/core/lightning.py (Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>)
* Update pytorch_lightning/core/lightning.py (Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>)
* fix whitespace
* Update pytorch_lightning/core/lightning.py (Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>)
* Parametrize tests
* Update pytorch_lightning/core/lightning.py (Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>)
* Update pytorch_lightning/core/lightning.py (Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>)
* seq
* fix docs
* Update lightning.py
* Update lightning.py
* fix docs errors
* add example keyword
* update docstring

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
48a10f16ef
commit
59acf574e5
|
@ -19,12 +19,13 @@ import inspect
|
|||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import types
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from argparse import Namespace
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch import ScriptModule, Tensor
|
||||
|
@ -1591,55 +1592,84 @@ class LightningModule(
|
|||
parents_arguments.update(args)
|
||||
return self_arguments, parents_arguments
|
||||
|
||||
def save_hyperparameters(self, *args, frame=None) -> None:
|
||||
"""Save all model arguments.
|
||||
def save_hyperparameters(
|
||||
self,
|
||||
*args,
|
||||
ignore: Optional[Union[Sequence[str], str]] = None,
|
||||
frame: Optional[types.FrameType] = None
|
||||
) -> None:
|
||||
"""Save model arguments to ``hparams`` attribute.
|
||||
|
||||
Args:
|
||||
args: single object of `dict`, `NameSpace` or `OmegaConf`
|
||||
or string names or arguments from class `__init__`
|
||||
or string names or arguments from class ``__init__``
|
||||
ignore: an argument name or a list of argument names from
|
||||
class ``__init__`` to be ignored
|
||||
frame: a frame object. Default is None
|
||||
|
||||
>>> class ManuallyArgsModel(LightningModule):
|
||||
... def __init__(self, arg1, arg2, arg3):
|
||||
... super().__init__()
|
||||
... # manually assign arguments
|
||||
... self.save_hyperparameters('arg1', 'arg3')
|
||||
... def forward(self, *args, **kwargs):
|
||||
... ...
|
||||
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
|
||||
>>> model.hparams
|
||||
"arg1": 1
|
||||
"arg3": 3.14
|
||||
Example::
|
||||
>>> class ManuallyArgsModel(LightningModule):
|
||||
... def __init__(self, arg1, arg2, arg3):
|
||||
... super().__init__()
|
||||
... # manually assign arguments
|
||||
... self.save_hyperparameters('arg1', 'arg3')
|
||||
... def forward(self, *args, **kwargs):
|
||||
... ...
|
||||
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
|
||||
>>> model.hparams
|
||||
"arg1": 1
|
||||
"arg3": 3.14
|
||||
|
||||
>>> class AutomaticArgsModel(LightningModule):
|
||||
... def __init__(self, arg1, arg2, arg3):
|
||||
... super().__init__()
|
||||
... # equivalent automatic
|
||||
... self.save_hyperparameters()
|
||||
... def forward(self, *args, **kwargs):
|
||||
... ...
|
||||
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
|
||||
>>> model.hparams
|
||||
"arg1": 1
|
||||
"arg2": abc
|
||||
"arg3": 3.14
|
||||
>>> class AutomaticArgsModel(LightningModule):
|
||||
... def __init__(self, arg1, arg2, arg3):
|
||||
... super().__init__()
|
||||
... # equivalent automatic
|
||||
... self.save_hyperparameters()
|
||||
... def forward(self, *args, **kwargs):
|
||||
... ...
|
||||
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
|
||||
>>> model.hparams
|
||||
"arg1": 1
|
||||
"arg2": abc
|
||||
"arg3": 3.14
|
||||
|
||||
>>> class SingleArgModel(LightningModule):
|
||||
... def __init__(self, params):
|
||||
... super().__init__()
|
||||
... # manually assign single argument
|
||||
... self.save_hyperparameters(params)
|
||||
... def forward(self, *args, **kwargs):
|
||||
... ...
|
||||
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
|
||||
>>> model.hparams
|
||||
"p1": 1
|
||||
"p2": abc
|
||||
"p3": 3.14
|
||||
>>> class SingleArgModel(LightningModule):
|
||||
... def __init__(self, params):
|
||||
... super().__init__()
|
||||
... # manually assign single argument
|
||||
... self.save_hyperparameters(params)
|
||||
... def forward(self, *args, **kwargs):
|
||||
... ...
|
||||
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
|
||||
>>> model.hparams
|
||||
"p1": 1
|
||||
"p2": abc
|
||||
"p3": 3.14
|
||||
|
||||
>>> class ManuallyArgsModel(LightningModule):
|
||||
... def __init__(self, arg1, arg2, arg3):
|
||||
... super().__init__()
|
||||
... # pass argument(s) to ignore as a string or in a list
|
||||
... self.save_hyperparameters(ignore='arg2')
|
||||
... def forward(self, *args, **kwargs):
|
||||
... ...
|
||||
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
|
||||
>>> model.hparams
|
||||
"arg1": 1
|
||||
"arg3": 3.14
|
||||
"""
|
||||
if not frame:
|
||||
frame = inspect.currentframe().f_back
|
||||
init_args = get_init_args(frame)
|
||||
assert init_args, "failed to inspect the self init"
|
||||
|
||||
if ignore is not None:
|
||||
if isinstance(ignore, str):
|
||||
ignore = [ignore]
|
||||
if isinstance(ignore, (list, tuple)):
|
||||
ignore = [arg for arg in ignore if isinstance(arg, str)]
|
||||
init_args = {k: v for k, v in init_args.items() if k not in ignore}
|
||||
|
||||
if not args:
|
||||
# take all arguments
|
||||
hp = init_args
|
||||
|
|
|
@ -661,3 +661,39 @@ def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir):
|
|||
)
|
||||
trainer.fit(model)
|
||||
_ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ignore", ("arg2", ("arg2", "arg3")))
def test_ignore_args_list_hparams(tmpdir, ignore):
    """
    Tests that constructor args named in ``ignore`` (a single string or a
    sequence of strings) are excluded from ``hparams`` by
    ``save_hyperparameters``, survive checkpointing, and stay excluded after
    ``load_from_checkpoint``.
    """

    class LocalModel(BoringModel):

        def __init__(self, arg1, arg2, arg3):
            super().__init__()
            self.save_hyperparameters(ignore=ignore)

    model = LocalModel(arg1=14, arg2=90, arg3=50)

    # `ignore` is parametrized as either a bare string or a tuple of names.
    # Normalize a bare string to a one-element tuple: iterating the string
    # directly would yield single characters ('a', 'r', 'g', '2') and the
    # membership asserts below would pass trivially without testing anything.
    ignored = (ignore,) if isinstance(ignore, str) else ignore

    # test proper property assignments
    assert model.hparams.arg1 == 14
    for arg in ignored:
        assert arg not in model.hparams

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
    assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["arg1"] == 14

    # verify that model loads correctly; ignored args must be re-supplied as
    # kwargs since they were never persisted in the checkpoint's hparams
    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, arg2=123, arg3=100)
    assert model.hparams.arg1 == 14
    for arg in ignored:
        assert arg not in model.hparams
|
||||
|
|
Loading…
Reference in New Issue