Prune deprecated hparams setter (#6207)

Carlos Mocholí 2021-02-27 13:24:50 +01:00 committed by GitHub
parent 40d5a9d6df
commit 111d9c7267
5 changed files with 17 additions and 101 deletions

@@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))
- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))
### Fixed
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))

@@ -17,7 +17,6 @@ import collections
import copy
import inspect
import os
import re
import tempfile
import uuid
from abc import ABC
@@ -1806,39 +1805,6 @@ class LightningModule(
# prevent any change
return copy.deepcopy(self._hparams_initial)
@hparams.setter
def hparams(self, hp: Union[dict, Namespace, Any]):
# TODO: remove this method in v1.3.0.
rank_zero_warn(
"The setter for self.hparams in LightningModule is deprecated since v1.1.0 and will be"
" removed in v1.3.0. Replace the assignment `self.hparams = hparams` with "
" `self.save_hyperparameters()`.", DeprecationWarning
)
hparams_assignment_name = self.__get_hparams_assignment_variable()
self._hparams_name = hparams_assignment_name
self._set_hparams(hp)
# this resolves the case when the user does not use `save_hyperparameters` and does a hard assignment in init
if not hasattr(self, "_hparams_initial"):
self._hparams_initial = copy.deepcopy(self._hparams)
def __get_hparams_assignment_variable(self):
"""
looks at the code of the class to figure out what the user named self.hparams
this only happens when the user explicitly sets self.hparams
"""
try:
class_code = inspect.getsource(self.__class__)
lines = class_code.split("\n")
for line in lines:
line = re.sub(r"\s+", "", line, flags=re.UNICODE)
if ".hparams=" in line:
return line.split("=")[1]
# todo: specify the possible exception
except Exception:
return "hparams"
return None
@property
def model_size(self) -> float:
# todo: think about better way without need to dump model to drive
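The deprecation message above already names the replacement: instead of assigning to `self.hparams` (and relying on `__get_hparams_assignment_variable` to guess the attribute name from the class source), call `self.save_hyperparameters()` in `__init__`. A minimal before/after sketch of that migration (the `LitModel` class and its arguments are illustrative, not part of this diff):

from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def __init__(self, learning_rate=1e-3, hidden_dim=128):
        super().__init__()
        # removed pattern: self.hparams = {"learning_rate": learning_rate, "hidden_dim": hidden_dim}
        # supported replacement: collect the __init__ arguments into self.hparams
        self.save_hyperparameters()

model = LitModel()
assert model.hparams.learning_rate == 1e-3  # populated by save_hyperparameters()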

@@ -1,30 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z"""
import pytest
from pytorch_lightning import LightningModule
def test_v1_3_0_deprecated_arguments(tmpdir):
with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"):
class DeprecatedHparamsModel(LightningModule):
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
DeprecatedHparamsModel({})
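With the setter pruned there is nothing left to warn about, so this deprecation test file is deleted outright. If an equivalent guard for the new behaviour were wanted, a sketch could look like the following, assuming that assigning to the now read-only `hparams` property raises `AttributeError` (this test is not part of the commit):

import pytest
from pytorch_lightning import LightningModule

def test_hparams_direct_assignment_raises():
    class DirectAssignmentModel(LightningModule):
        def __init__(self, hparams):
            super().__init__()
            self.hparams = hparams  # no setter exists for this property anymore

    with pytest.raises(AttributeError):
        DirectAssignmentModel({})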

@@ -41,14 +41,6 @@ class SaveHparamsModel(BoringModel):
self.save_hyperparameters(hparams)
class AssignHparamsModel(BoringModel):
""" Tests that a model can take an object with explicit setter """
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
def decorate(func):
@functools.wraps(func)
@@ -68,16 +60,6 @@ class SaveHparamsDecoratedModel(BoringModel):
self.save_hyperparameters(hparams)
class AssignHparamsDecoratedModel(BoringModel):
""" Tests that a model can take an object with explicit setter"""
@decorate
@decorate
def __init__(self, hparams, *my_args, **my_kwargs):
super().__init__()
self.hparams = hparams
# -------------------------
# STANDARD TESTS
# -------------------------
@@ -114,7 +96,7 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
@pytest.mark.parametrize(
"cls", [SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel]
"cls", [SaveHparamsModel, SaveHparamsDecoratedModel]
)
def test_namespace_hparams(tmpdir, cls):
# init model
@@ -125,7 +107,7 @@ def test_namespace_hparams(tmpdir, cls):
@pytest.mark.parametrize(
"cls", [SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel]
"cls", [SaveHparamsModel, SaveHparamsDecoratedModel]
)
def test_dict_hparams(tmpdir, cls):
# init model
@@ -136,7 +118,7 @@ def test_dict_hparams(tmpdir, cls):
@pytest.mark.parametrize(
"cls", [SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel]
"cls", [SaveHparamsModel, SaveHparamsDecoratedModel]
)
def test_omega_conf_hparams(tmpdir, cls):
# init model
@@ -580,8 +562,7 @@ class SuperClassPositionalArgs(BoringModel):
def __init__(self, hparams):
super().__init__()
self._hparams = None # pretend BoringModel did not call self.save_hyperparameters()
self.hparams = hparams
self._hparams = hparams # pretend BoringModel did not call self.save_hyperparameters()
class SubClassVarArgs(SuperClassPositionalArgs):
@@ -617,8 +598,6 @@ def test_init_arg_with_runtime_change(tmpdir, cls):
assert model.hparams.running_arg == 123
model.hparams.running_arg = -1
assert model.hparams.running_arg == -1
model.hparams = Namespace(abc=42)
assert model.hparams.abc == 42
trainer = Trainer(
default_root_dir=tmpdir,
@@ -664,18 +643,11 @@ def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir):
def __init__(self, args_0, args_1, args_2, kwarg_1=None):
self.save_hyperparameters()
self.test_hparams()
config_file = f"{tmpdir}/hparams.yaml"
save_hparams_to_yaml(config_file, self.hparams)
self.hparams = load_hparams_from_yaml(config_file)
self.test_hparams()
super().__init__()
def test_hparams(self):
assert self.hparams.args_0.log == "Something"
assert self.hparams.args_1['cfg'].log == "Something"
assert self.hparams.args_2[0].log == "Something"
assert self.hparams.kwarg_1['cfg'][0].log == "Something"
super().__init__()
with initialize(config_path="conf"):
args_0 = compose(config_name="config")
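The lines removed from the hydra test round-tripped the hyperparameters through YAML by assigning the loaded result back to `self.hparams`, which only worked because of the pruned setter. A comparable check without the setter might compare the reloaded mapping directly, along these lines (a rough sketch only; `model` and `tmpdir` stand in for the objects in the test, and the import path for the YAML helpers is assumed):

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

config_file = f"{tmpdir}/hparams.yaml"
save_hparams_to_yaml(config_file, model.hparams)
loaded = load_hparams_from_yaml(config_file)
# inspect the reloaded values instead of assigning them back to self.hparams
assert set(loaded) == set(model.hparams)  # same hyperparameter keys survive the round-trip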

@@ -23,6 +23,7 @@ from pytorch_lightning import Trainer
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
@@ -282,10 +283,14 @@ def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
""" Test for a warning when model.batch_size and model.hparams.batch_size both present. """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
model.hparams = hparams
# now we have model.batch_size and model.hparams.batch_size
class TestModel(BoringModel):
def __init__(self, batch_size=1):
super().__init__()
# now we have model.batch_size and model.hparams.batch_size
self.batch_size = 1
self.save_hyperparameters()
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000, auto_scale_batch_size=True)
expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!"
with pytest.warns(UserWarning, match=expected_message):