Deprecate TrainerModelHooksMixin (#7422)

* Deprecate TrainerModelHooksMixin

* Update CHANGELOG.md

* Update model_hooks.py

* Update model_hooks.py
This commit is contained in:
ananthsub 2021-05-07 13:19:36 -07:00 committed by GitHub
parent 8208c330eb
commit fecce50355
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 7 deletions

View File

@@ -22,6 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))
### Removed

View File

@@ -12,18 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from abc import ABC
from typing import Optional
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
class TrainerModelHooksMixin(ABC):
    """Deprecated trainer-side helpers for inspecting hooks on the LightningModule.

    TODO: Remove this class in v1.6.
    Use the utilities from ``pytorch_lightning.utilities.signature_utils`` instead.
    """

    # Provided by the Trainer this mixin is composed into.
    lightning_module: LightningModule

    def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool:
        """Return ``True`` if ``model`` (default: ``self.lightning_module``) has a callable ``f_name``.

        .. deprecated:: v1.4
            Will be removed in v1.6.
        """
        rank_zero_deprecation(
            "Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4"
            " and will be removed in v1.6."
        )
        # note: currently unused - kept as it is public
        if model is None:
            model = self.lightning_module
        f_op = getattr(model, f_name, None)
        return callable(f_op)

    def has_arg(self, f_name: str, arg_name: str) -> bool:
        """Return ``True`` if the hook ``f_name`` on the LightningModule accepts a parameter ``arg_name``.

        .. deprecated:: v1.4
            Will be removed in v1.6. Use
            ``pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature`` instead.
        """
        # Fix: the deprecation message must name `has_arg`, not `is_function_implemented`.
        rank_zero_deprecation(
            "Internal: TrainerModelHooksMixin.has_arg is deprecated in v1.4"
            " and will be removed in v1.6."
            " Use `pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature` instead."
        )
        model = self.lightning_module
        f_op = getattr(model, f_name, None)
        # Guard before inspecting the signature: a missing hook simply has no such argument.
        if not f_op:
            return False
        return is_param_in_hook_signature(f_op, arg_name)

View File

@@ -936,18 +936,20 @@ class TrainLoop:
# enable not needing to add opt_idx to training_step
args = [batch, batch_idx]
lightning_module = self.trainer.lightning_module
if len(self.trainer.optimizers) > 1:
if self.trainer.has_arg("training_step", "optimizer_idx"):
if not self.trainer.lightning_module.automatic_optimization:
training_step_fx = getattr(lightning_module, "training_step")
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
if has_opt_idx_in_train_step:
if not lightning_module.automatic_optimization:
self.warning_cache.warn(
"`training_step` hook signature has changed in v1.3."
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
" the old signature will be removed in v1.5", DeprecationWarning
)
args.append(opt_idx)
elif not self.trainer.has_arg(
"training_step", "optimizer_idx"
) and self.trainer.lightning_module.automatic_optimization:
elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
raise ValueError(
f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
' `training_step` is missing the `optimizer_idx` argument.'

View File

@@ -0,0 +1,30 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Test deprecated functionality which will be removed in v1.6.0 """
import pytest
from pytorch_lightning import Trainer
from tests.helpers import BoringModel
def test_v1_6_0_trainer_model_hook_mixin(tmpdir):
    """Both deprecated ``TrainerModelHooksMixin`` methods must raise a deprecation warning."""
    boring_model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False)
    trainer.fit(boring_model)

    expected_warning = "is deprecated in v1.4 and will be removed in v1.6"
    with pytest.deprecated_call(match=expected_warning):
        trainer.is_function_implemented("training_step", boring_model)
    with pytest.deprecated_call(match=expected_warning):
        trainer.has_arg("training_step", "batch")