Remove the deprecated tuning property and enums (#16379)

This commit is contained in:
Carlos Mocholí 2023-01-17 02:52:30 +01:00 committed by Luca Antiga
parent 67eb931cdf
commit 3cc376e240
6 changed files with 11 additions and 101 deletions

View File

@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))
- Tuner removal
* Removed the deprecated `trainer.tuning` property ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
* Removed the deprecated `TrainerFn.TUNING` and `RunningStage.TUNING` enums ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
### Fixed

View File

@ -95,8 +95,8 @@ class Timer(Callback):
self._duration = duration.total_seconds() if duration is not None else None
self._interval = interval
self._verbose = verbose
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
self._offset = 0
def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
@ -161,7 +161,7 @@ class Timer(Callback):
self._check_time_remaining(trainer)
def state_dict(self) -> Dict[str, Any]:
return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage._without_tune()}}
return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage}}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
time_elapsed = state_dict.get("time_elapsed", {})

View File

@ -12,37 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from enum import Enum, EnumMeta
from typing import Any, List, Optional
from typing import Optional
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
class _DeprecationManagingEnumMeta(EnumMeta):
"""Enum that calls `deprecate()` whenever a member is accessed.
Adapted from: https://stackoverflow.com/a/62309159/208880
"""
def __getattribute__(cls, name: str) -> Any:
obj = super().__getattribute__(name)
# ignore __dunder__ names -- prevents potential recursion errors
if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
obj.deprecate()
return obj
def __getitem__(cls, name: str) -> Any:
member: _DeprecationManagingEnumMeta = super().__getitem__(name)
member.deprecate()
return member
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
obj = super().__call__(*args, **kwargs)
if isinstance(obj, Enum):
obj.deprecate()
return obj
class TrainerStatus(LightningEnum):
@ -58,7 +31,7 @@ class TrainerStatus(LightningEnum):
return self in (self.FINISHED, self.INTERRUPTED)
class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
class TrainerFn(LightningEnum):
"""
Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer`
such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
@ -69,21 +42,9 @@ class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
VALIDATING = "validate"
TESTING = "test"
PREDICTING = "predict"
TUNING = "tune"
def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0."
)
@classmethod
def _without_tune(cls) -> List["TrainerFn"]:
fns = [fn for fn in cls if fn != "tune"]
return fns
class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
class RunningStage(LightningEnum):
"""Enum for the current running stage.
This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
@ -93,7 +54,6 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
- ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
- ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
- ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
- ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
"""
TRAINING = "train"
@ -101,7 +61,6 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
VALIDATING = "validate"
TESTING = "test"
PREDICTING = "predict"
TUNING = "tune"
@property
def evaluating(self) -> bool:
@ -115,17 +74,6 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
return "val"
return self.value
def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0."
)
@classmethod
def _without_tune(cls) -> List["RunningStage"]:
fns = [fn for fn in cls if fn != "tune"]
return fns
@dataclass
class TrainerState:

View File

@ -96,7 +96,7 @@ from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.seed import isolate_rng
from pytorch_lightning.utilities.types import (
_EVALUATE_OUTPUT,
@ -1891,20 +1891,6 @@ class Trainer:
elif self.predicting:
self.state.stage = None
@property
def tuning(self) -> bool:
rank_zero_deprecation("`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.")
return self.state.stage == RunningStage.TUNING
@tuning.setter
def tuning(self, val: bool) -> None:
rank_zero_deprecation("Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.")
if val:
self.state.stage = RunningStage.TUNING
elif self.tuning:
self.state.stage = None
@property
def validating(self) -> bool:
return self.state.stage == RunningStage.VALIDATING

View File

@ -20,7 +20,6 @@ import torch
from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
@ -30,7 +29,6 @@ from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel,
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies.bagua import LightningBaguaModule
from pytorch_lightning.strategies.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.apply_func import (
apply_to_collection,
apply_to_collections,
@ -271,30 +269,6 @@ def test_v1_10_deprecated_accelerator_setup_environment_method():
CPUAccelerator().setup_environment(torch.device("cpu"))
def test_tuning_enum():
with pytest.deprecated_call(
match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0."
):
TrainerFn.TUNING
with pytest.deprecated_call(
match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0."
):
RunningStage.TUNING
def test_tuning_trainer_property():
trainer = Trainer()
with pytest.deprecated_call(match="`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0."):
trainer.tuning
with pytest.deprecated_call(
match="Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0."
):
trainer.tuning = True
def test_v1_8_1_deprecated_rank_zero_only():
from pytorch_lightning.utilities.distributed import rank_zero_only

View File

@ -150,8 +150,7 @@ def test_loops_restore(tmpdir):
trainer = Trainer(**trainer_args)
trainer.strategy.connect(model)
trainer_fns = [fn for fn in TrainerFn._without_tune()]
trainer_fns = list(TrainerFn)
for fn in trainer_fns:
trainer_fn = getattr(trainer, f"{fn}_loop")
trainer_fn.load_state_dict = mock.Mock()