Deprecate TrainerModelHooksMixin (#7422)
* Deprecate TrainerModelHooksMixin * Update CHANGELOG.md * Update model_hooks.py * Update model_hooks.py
This commit is contained in:
parent
8208c330eb
commit
fecce50355
|
@ -22,6 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
### Deprecated
|
||||
|
||||
|
||||
- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
|
||||
|
|
|
@ -12,18 +12,28 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
|
||||
|
||||
class TrainerModelHooksMixin(ABC):
|
||||
"""
|
||||
TODO: Remove this class in v1.6.
|
||||
|
||||
Use the utilities from ``pytorch_lightning.utilities.signature_utils`` instead.
|
||||
"""
|
||||
|
||||
lightning_module: LightningModule
|
||||
|
||||
def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool:
|
||||
rank_zero_deprecation(
|
||||
"Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4"
|
||||
" and will be removed in v1.6."
|
||||
)
|
||||
# note: currently unused - kept as it is public
|
||||
if model is None:
|
||||
model = self.lightning_module
|
||||
|
@ -31,6 +41,13 @@ class TrainerModelHooksMixin(ABC):
|
|||
return callable(f_op)
|
||||
|
||||
def has_arg(self, f_name: str, arg_name: str) -> bool:
|
||||
rank_zero_deprecation(
|
||||
"Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4"
|
||||
" and will be removed in v1.6."
|
||||
" Use `pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature` instead."
|
||||
)
|
||||
model = self.lightning_module
|
||||
f_op = getattr(model, f_name, None)
|
||||
return arg_name in inspect.signature(f_op).parameters
|
||||
if not f_op:
|
||||
return False
|
||||
return is_param_in_hook_signature(f_op, arg_name)
|
||||
|
|
|
@ -936,18 +936,20 @@ class TrainLoop:
|
|||
# enable not needing to add opt_idx to training_step
|
||||
args = [batch, batch_idx]
|
||||
|
||||
lightning_module = self.trainer.lightning_module
|
||||
|
||||
if len(self.trainer.optimizers) > 1:
|
||||
if self.trainer.has_arg("training_step", "optimizer_idx"):
|
||||
if not self.trainer.lightning_module.automatic_optimization:
|
||||
training_step_fx = getattr(lightning_module, "training_step")
|
||||
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
|
||||
if has_opt_idx_in_train_step:
|
||||
if not lightning_module.automatic_optimization:
|
||||
self.warning_cache.warn(
|
||||
"`training_step` hook signature has changed in v1.3."
|
||||
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
|
||||
" the old signature will be removed in v1.5", DeprecationWarning
|
||||
)
|
||||
args.append(opt_idx)
|
||||
elif not self.trainer.has_arg(
|
||||
"training_step", "optimizer_idx"
|
||||
) and self.trainer.lightning_module.automatic_optimization:
|
||||
elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
|
||||
raise ValueError(
|
||||
f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
|
||||
' `training_step` is missing the `optimizer_idx` argument.'
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Test deprecated functionality which will be removed in v1.6.0 """
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from tests.helpers import BoringModel
|
||||
|
||||
|
||||
def test_v1_6_0_trainer_model_hook_mixin(tmpdir):
|
||||
model = BoringModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False)
|
||||
trainer.fit(model)
|
||||
with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"):
|
||||
trainer.is_function_implemented("training_step", model)
|
||||
|
||||
with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"):
|
||||
trainer.has_arg("training_step", "batch")
|
Loading…
Reference in New Issue