From fecce50355b0c640f62093029b3af2429d12f99c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 7 May 2021 13:19:36 -0700 Subject: [PATCH] Deprecate TrainerModelHooksMixin (#7422) * Deprecate TrainerModelHooksMixin * Update CHANGELOG.md * Update model_hooks.py * Update model_hooks.py --- CHANGELOG.md | 3 +++ pytorch_lightning/trainer/model_hooks.py | 21 +++++++++++++-- pytorch_lightning/trainer/training_loop.py | 12 +++++---- tests/deprecated_api/test_remove_1-6.py | 30 ++++++++++++++++++++++ 4 files changed, 59 insertions(+), 7 deletions(-) create mode 100644 tests/deprecated_api/test_remove_1-6.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fdb02fb23e..899f9f74b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422)) + + ### Removed diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index b924675d85..86cb1334a7 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -12,18 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect from abc import ABC from typing import Optional from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.distributed import rank_zero_deprecation +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature class TrainerModelHooksMixin(ABC): + """ + TODO: Remove this class in v1.6. + + Use the utilities from ``pytorch_lightning.utilities.signature_utils`` instead.
+ """ lightning_module: LightningModule def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool: + rank_zero_deprecation( + "Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4" + " and will be removed in v1.6." + ) # note: currently unused - kept as it is public if model is None: model = self.lightning_module @@ -31,6 +41,13 @@ class TrainerModelHooksMixin(ABC): return callable(f_op) def has_arg(self, f_name: str, arg_name: str) -> bool: + rank_zero_deprecation( + "Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4" + " and will be removed in v1.6." + " Use `pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature` instead." + ) model = self.lightning_module f_op = getattr(model, f_name, None) - return arg_name in inspect.signature(f_op).parameters + if not f_op: + return False + return is_param_in_hook_signature(f_op, arg_name) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 790dc4c70b..9c6f4b5458 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -936,18 +936,20 @@ class TrainLoop: # enable not needing to add opt_idx to training_step args = [batch, batch_idx] + lightning_module = self.trainer.lightning_module + if len(self.trainer.optimizers) > 1: - if self.trainer.has_arg("training_step", "optimizer_idx"): - if not self.trainer.lightning_module.automatic_optimization: + training_step_fx = getattr(lightning_module, "training_step") + has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") + if has_opt_idx_in_train_step: + if not lightning_module.automatic_optimization: self.warning_cache.warn( "`training_step` hook signature has changed in v1.3." " `optimizer_idx` argument has been removed in case of manual optimization. 
Support for" " the old signature will be removed in v1.5", DeprecationWarning ) args.append(opt_idx) - elif not self.trainer.has_arg( - "training_step", "optimizer_idx" - ) and self.trainer.lightning_module.automatic_optimization: + elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization: raise ValueError( f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" ' `training_step` is missing the `optimizer_idx` argument.' diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py new file mode 100644 index 0000000000..09312a4c41 --- /dev/null +++ b/tests/deprecated_api/test_remove_1-6.py @@ -0,0 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Test deprecated functionality which will be removed in v1.6.0 """ + +import pytest + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel + + +def test_v1_6_0_trainer_model_hook_mixin(tmpdir): + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False) + trainer.fit(model) + with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"): + trainer.is_function_implemented("training_step", model) + + with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"): + trainer.has_arg("training_step", "batch")