Remove the deprecated tuning property and enums (#16379)
parent 67eb931cdf
commit 3cc376e240
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))

+- Tuner removal
+  * Removed the deprecated `trainer.tuning` property ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
+  * Removed the deprecated `TrainerFn.TUNING` and `RunningStage.TUNING` enums ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))

 ### Fixed

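Migration note, as a minimal sketch under the assumption that only the public imports shown below are used: with the `TUNING` members gone, plain enum iteration now does what the private `_without_tune()` helpers did, and there is no `tuning` flag left to read on the Trainer.

from pytorch_lightning.trainer.states import RunningStage, TrainerFn

# Neither enum has a TUNING member anymore, so iterating them already yields
# only the real stages -- the filtering `_without_tune()` did is now a no-op.
stage_times = {stage: None for stage in RunningStage}
fns = list(TrainerFn)
assert all(fn.value != "tune" for fn in fns)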
@@ -95,8 +95,8 @@ class Timer(Callback):
         self._duration = duration.total_seconds() if duration is not None else None
         self._interval = interval
         self._verbose = verbose
-        self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
-        self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
+        self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
+        self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
         self._offset = 0

     def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
@@ -161,7 +161,7 @@ class Timer(Callback):
         self._check_time_remaining(trainer)

     def state_dict(self) -> Dict[str, Any]:
-        return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage._without_tune()}}
+        return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage}}

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         time_elapsed = state_dict.get("time_elapsed", {})
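Because the bookkeeping dicts are now keyed over the full `RunningStage` enum, a `Timer` checkpoint round-trips one entry per stage. A small sketch using only the public `Timer` API touched in this file; the printed keys are illustrative:

from datetime import timedelta

from pytorch_lightning.callbacks import Timer

timer = Timer(duration=timedelta(minutes=5))
state = timer.state_dict()
# One key per RunningStage value, e.g. "sanity_check", "train", "validate", ...
print(sorted(state["time_elapsed"]))

restored = Timer()
restored.load_state_dict(state)  # restores the elapsed-time offsets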
@@ -12,37 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass, field
-from enum import Enum, EnumMeta
-from typing import Any, List, Optional
+from typing import Optional

 from pytorch_lightning.utilities import LightningEnum
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
-
-
-class _DeprecationManagingEnumMeta(EnumMeta):
-    """Enum that calls `deprecate()` whenever a member is accessed.
-
-    Adapted from: https://stackoverflow.com/a/62309159/208880
-    """
-
-    def __getattribute__(cls, name: str) -> Any:
-        obj = super().__getattribute__(name)
-        # ignore __dunder__ names -- prevents potential recursion errors
-        if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
-            obj.deprecate()
-        return obj
-
-    def __getitem__(cls, name: str) -> Any:
-        member: _DeprecationManagingEnumMeta = super().__getitem__(name)
-        member.deprecate()
-        return member
-
-    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
-        obj = super().__call__(*args, **kwargs)
-        if isinstance(obj, Enum):
-            obj.deprecate()
-        return obj


 class TrainerStatus(LightningEnum):
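For context, the deleted `_DeprecationManagingEnumMeta` implemented a generic "warn on member access" pattern (per the StackOverflow answer linked in its docstring). A self-contained sketch of the same technique with illustrative names; `_DeprecatingEnumMeta` and `Color` are not Lightning APIs:

import warnings
from enum import Enum, EnumMeta
from typing import Any


class _DeprecatingEnumMeta(EnumMeta):
    """Calls `deprecate()` on a member whenever it is accessed via the class."""

    def __getattribute__(cls, name: str) -> Any:
        obj = super().__getattribute__(name)
        # Skip __dunder__ names; intercepting those can recurse during class creation.
        if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
            obj.deprecate()
        return obj


class Color(Enum, metaclass=_DeprecatingEnumMeta):
    RED = "red"
    BLUE = "blue"

    def deprecate(self) -> None:
        # `self.BLUE` is an instance lookup, so it bypasses the metaclass hook
        # and cannot recurse -- the same trick the deleted code used.
        if self == self.BLUE:
            warnings.warn("`Color.BLUE` is deprecated", DeprecationWarning, stacklevel=3)


Color.RED   # no warning
Color.BLUE  # emits the DeprecationWarning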
@@ -58,7 +31,7 @@ class TrainerStatus(LightningEnum):
         return self in (self.FINISHED, self.INTERRUPTED)


-class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
+class TrainerFn(LightningEnum):
     """
     Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer`
     such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
@@ -69,21 +42,9 @@ class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
     VALIDATING = "validate"
     TESTING = "test"
     PREDICTING = "predict"
-    TUNING = "tune"
-
-    def deprecate(self) -> None:
-        if self == self.TUNING:
-            rank_zero_deprecation(
-                f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0."
-            )
-
-    @classmethod
-    def _without_tune(cls) -> List["TrainerFn"]:
-        fns = [fn for fn in cls if fn != "tune"]
-        return fns


-class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
+class RunningStage(LightningEnum):
     """Enum for the current running stage.

     This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
@@ -93,7 +54,6 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
     - ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
     - ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
     - ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
-    - ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
     """

     TRAINING = "train"
@@ -101,7 +61,6 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
     VALIDATING = "validate"
     TESTING = "test"
     PREDICTING = "predict"
-    TUNING = "tune"

     @property
     def evaluating(self) -> bool:
@@ -115,17 +74,6 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
             return "val"
         return self.value

-    def deprecate(self) -> None:
-        if self == self.TUNING:
-            rank_zero_deprecation(
-                f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0."
-            )
-
-    @classmethod
-    def _without_tune(cls) -> List["RunningStage"]:
-        fns = [fn for fn in cls if fn != "tune"]
-        return fns
-

 @dataclass
 class TrainerState:
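One detail worth noting: the removed `_without_tune()` helpers filtered with `fn != "tune"`, which works because `LightningEnum` members compare equal to their string values; that comparison behaviour is untouched by this commit. A quick illustration:

from pytorch_lightning.trainer.states import RunningStage, TrainerFn

assert TrainerFn.FITTING == "fit"             # LightningEnum string comparison
assert RunningStage.TRAINING == "train"
assert all(fn != "tune" for fn in TrainerFn)  # TUNING no longer exists at all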
@@ -96,7 +96,7 @@ from pytorch_lightning.utilities.data import has_len_all_ranks
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.seed import isolate_rng
 from pytorch_lightning.utilities.types import (
     _EVALUATE_OUTPUT,
@@ -1891,20 +1891,6 @@ class Trainer:
         elif self.predicting:
             self.state.stage = None

-    @property
-    def tuning(self) -> bool:
-        rank_zero_deprecation("`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.")
-        return self.state.stage == RunningStage.TUNING
-
-    @tuning.setter
-    def tuning(self, val: bool) -> None:
-        rank_zero_deprecation("Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.")
-
-        if val:
-            self.state.stage = RunningStage.TUNING
-        elif self.tuning:
-            self.state.stage = None
-
     @property
     def validating(self) -> bool:
         return self.state.stage == RunningStage.VALIDATING
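With the property gone there is no drop-in replacement for `trainer.tuning`; stage checks go through `trainer.state.stage` and the remaining flags, following the same pattern as the surviving `validating` property above. A minimal sketch; the assertions only describe a freshly constructed Trainer:

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage

trainer = Trainer()
assert trainer.state.stage is None                   # nothing has run yet
assert not trainer.training                          # same comparison as below
assert trainer.state.stage != RunningStage.TRAINING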
@@ -20,7 +20,6 @@ import torch
 from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import DataLoader

-from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
@@ -30,7 +29,6 @@ from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel,
 from pytorch_lightning.plugins.environments import LightningEnvironment
 from pytorch_lightning.strategies.bagua import LightningBaguaModule
 from pytorch_lightning.strategies.utils import on_colab_kaggle
-from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities.apply_func import (
     apply_to_collection,
     apply_to_collections,
@@ -271,30 +269,6 @@ def test_v1_10_deprecated_accelerator_setup_environment_method():
         CPUAccelerator().setup_environment(torch.device("cpu"))


-def test_tuning_enum():
-    with pytest.deprecated_call(
-        match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0."
-    ):
-        TrainerFn.TUNING
-
-    with pytest.deprecated_call(
-        match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0."
-    ):
-        RunningStage.TUNING
-
-
-def test_tuning_trainer_property():
-    trainer = Trainer()
-
-    with pytest.deprecated_call(match="`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0."):
-        trainer.tuning
-
-    with pytest.deprecated_call(
-        match="Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0."
-    ):
-        trainer.tuning = True
-
-
 def test_v1_8_1_deprecated_rank_zero_only():
     from pytorch_lightning.utilities.distributed import rank_zero_only

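If a guard against these names creeping back were wanted, a removal (unlike a deprecation) would be checked with `pytest.raises` rather than `pytest.deprecated_call`. A hypothetical test, not part of this commit; `test_tuning_names_removed` is an illustrative name:

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage, TrainerFn


def test_tuning_names_removed():
    # The members and the property were removed outright, so plain
    # attribute access must now raise instead of warning.
    with pytest.raises(AttributeError):
        TrainerFn.TUNING
    with pytest.raises(AttributeError):
        RunningStage.TUNING
    with pytest.raises(AttributeError):
        Trainer().tuning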
@@ -150,8 +150,7 @@ def test_loops_restore(tmpdir):
     trainer = Trainer(**trainer_args)
     trainer.strategy.connect(model)

-    trainer_fns = [fn for fn in TrainerFn._without_tune()]
-
+    trainer_fns = list(TrainerFn)
     for fn in trainer_fns:
         trainer_fn = getattr(trainer, f"{fn}_loop")
         trainer_fn.load_state_dict = mock.Mock()