From a44b5dc0cb1489a805b73d0a786d34b9dea80c86 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Sat, 9 Apr 2022 01:10:19 +0530
Subject: [PATCH] Don't raise a warning when `nn.Module`s are not saved under
 hparams (#12669)

---
 CHANGELOG.md                           |  3 +++
 pytorch_lightning/utilities/parsing.py | 19 ++++++++-----------
 tests/models/test_hparams.py           | 17 ++++++++++++++---
 3 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7d2e52577d..8e3052648e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -98,6 +98,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `rank_zero_only` decorator in LSF environments ([#12587](https://github.com/PyTorchLightning/pytorch-lightning/pull/12587))
 
+- Don't raise a warning when `nn.Module` is not saved under hparams ([#12669](https://github.com/PyTorchLightning/pytorch-lightning/pull/12669))
+
+
 -
 
diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py
index e535d52847..92675bc991 100644
--- a/pytorch_lightning/utilities/parsing.py
+++ b/pytorch_lightning/utilities/parsing.py
@@ -234,17 +234,7 @@ def save_hyperparameters(
         ignore = [arg for arg in ignore if isinstance(arg, str)]
 
     ignore = list(set(ignore))
-
-    for k in list(init_args):
-        if k in ignore:
-            del init_args[k]
-            continue
-
-        if isinstance(init_args[k], nn.Module):
-            rank_zero_warn(
-                f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
-                f" It is recommended to ignore them using `self.save_hyperparameters(ignore=[{k!r}])`."
-            )
+    init_args = {k: v for k, v in init_args.items() if k not in ignore}
 
     if not args:
         # take all arguments
@@ -266,6 +256,13 @@ def save_hyperparameters(
     # make deep copy so there is not other runtime changes reflected
     obj._hparams_initial = copy.deepcopy(obj._hparams)
 
+    for k, v in obj._hparams.items():
+        if isinstance(v, nn.Module):
+            rank_zero_warn(
+                f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
+                f" It is recommended to ignore them using `self.save_hyperparameters(ignore=[{k!r}])`."
+            )
+
 
 class AttributeDict(Dict):
     """Extended dictionary accessible with dot notation.
diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index 63424c8f47..40bde5b241 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -34,6 +34,7 @@ from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, _OMEGACON
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
+from tests.helpers.utils import no_warning_call
 
 if _HYDRA_EXPERIMENTAL_AVAILABLE:
     from hydra.experimental import compose, initialize
@@ -819,18 +820,28 @@ def test_colliding_hparams(tmpdir):
     trainer.fit(model, datamodule=data)
 
 
-def test_nn_modules_raises_warning_when_saved_as_hparams():
+def test_nn_modules_warning_when_saved_as_hparams():
     class TorchModule(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.l1 = torch.nn.Linear(4, 5)
 
-    class CustomBoringModel(BoringModel):
+    class CustomBoringModelWarn(BoringModel):
         def __init__(self, encoder, decoder, other_hparam=7):
             super().__init__()
             self.save_hyperparameters()
 
     with pytest.warns(UserWarning, match="is an instance of `nn.Module` and is already saved"):
-        model = CustomBoringModel(encoder=TorchModule(), decoder=TorchModule())
+        model = CustomBoringModelWarn(encoder=TorchModule(), decoder=TorchModule())
 
     assert list(model.hparams) == ["encoder", "decoder", "other_hparam"]
+
+    class CustomBoringModelNoWarn(BoringModel):
+        def __init__(self, encoder, decoder, other_hparam=7):
+            super().__init__()
+            self.save_hyperparameters("other_hparam")
+
+    with no_warning_call(UserWarning, match="is an instance of `nn.Module` and is already saved"):
+        model = CustomBoringModelNoWarn(encoder=TorchModule(), decoder=TorchModule())
+
+    assert list(model.hparams) == ["other_hparam"]
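For reference, a minimal sketch of the user-facing behaviour this patch targets, assuming user-defined `Encoder` and `LitModel` classes (illustrative names, not from the repository): passing an `nn.Module` to `save_hyperparameters()` still warns, while excluding it via `ignore` (or listing only the plain hyperparameters) does not.

    # Illustrative sketch, not part of the patch.
    import torch
    from pytorch_lightning import LightningModule


    class Encoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(4, 5)


    class LitModel(LightningModule):
        def __init__(self, encoder, lr=1e-3):
            super().__init__()
            # The encoder's weights already end up in the checkpoint's state_dict,
            # so it is excluded from hparams; with `ignore=["encoder"]` no warning
            # is raised. A plain `self.save_hyperparameters()` would still warn.
            self.save_hyperparameters(ignore=["encoder"])
            self.encoder = encoder


    model = LitModel(Encoder())
    assert list(model.hparams) == ["lr"]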