From 2577285dd5b31d3457d8d40a738228137cf8f87a Mon Sep 17 00:00:00 2001
From: Quazi Marufur Rahman
Date: Mon, 12 Dec 2022 23:33:26 +1000
Subject: [PATCH] Use LRScheduler for torch >= 1.14 otherwise use _LRScheduler (#15768)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí
---
 .../cli/lightning_cli_intermediate_2.rst      |  2 +-
 src/lightning_lite/utilities/types.py         |  9 +++++++--
 .../callbacks/stochastic_weight_avg.py        |  6 +++---
 src/pytorch_lightning/core/optimizer.py       |  4 ++--
 src/pytorch_lightning/demos/boring_classes.py |  4 ++--
 src/pytorch_lightning/strategies/deepspeed.py |  4 ++--
 src/pytorch_lightning/strategies/hivemind.py  |  4 ++--
 src/pytorch_lightning/tuner/lr_finder.py      | 14 +++++++-------
 src/pytorch_lightning/utilities/types.py      | 19 ++++++-------------
 .../optimization/test_manual_optimization.py  |  2 +-
 .../trainer/optimization/test_optimizers.py   |  2 +-
 11 files changed, 34 insertions(+), 36 deletions(-)

diff --git a/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst b/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst
index 04a2795840..c60c339607 100644
--- a/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst
+++ b/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst
@@ -201,7 +201,7 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t
 
     python main.py fit --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=epoch
 
-Furthermore, any custom subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler:
+Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:
 
 .. code:: python
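Context for the docs change above: any scheduler class with the standard ``get_lr`` contract qualifies. A minimal sketch of such a custom scheduler, written against the public base name with a fallback for older torch releases (the class name, its hyperparameters, and the fallback import are illustrative, not part of this patch):

.. code:: python

    import torch
    from torch.optim import Optimizer

    try:  # torch >= 1.14 / 2.0 exposes the public base class name
        from torch.optim.lr_scheduler import LRScheduler
    except ImportError:  # torch <= 1.13 only ships the private name
        from torch.optim.lr_scheduler import _LRScheduler as LRScheduler


    class ConstantWarmupLR(LRScheduler):
        """Scale the base LR down for the first few steps, then use it unchanged."""

        def __init__(self, optimizer: Optimizer, warmup_steps: int = 5, factor: float = 0.1, last_epoch: int = -1):
            self.warmup_steps = warmup_steps
            self.factor = factor
            super().__init__(optimizer, last_epoch)

        def get_lr(self):
            scale = self.factor if self.last_epoch < self.warmup_steps else 1.0
            return [base_lr * scale for base_lr in self.base_lrs]

With the CLI such a class would then be selected by its class path, e.g. ``--lr_scheduler=my_module.ConstantWarmupLR`` (module path assumed).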
diff --git a/src/lightning_lite/utilities/types.py b/src/lightning_lite/utilities/types.py
index a3ee70ea68..de834212ec 100644
--- a/src/lightning_lite/utilities/types.py
+++ b/src/lightning_lite/utilities/types.py
@@ -19,7 +19,7 @@ from torch import Tensor
 from torch.optim import Optimizer
 from typing_extensions import Protocol, runtime_checkable
 
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14
 
 _PATH = Union[str, Path]
 _DEVICE = Union[torch.device, str, int]
@@ -63,7 +63,7 @@ class CollectibleGroup(Protocol):
 
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class _LRScheduler(_Stateful[str], Protocol):
+class LRScheduler(_Stateful[str], Protocol):
     optimizer: Optimizer
     base_lrs: List[float]
@@ -74,6 +74,11 @@ class _LRScheduler(_Stateful[str], Protocol):
         ...
 
 
+_TORCH_LRSCHEDULER = (
+    torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_1_14 else torch.optim.lr_scheduler._LRScheduler
+)
+
+
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 53111868f3..b2577358eb 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -23,7 +23,7 @@ from torch import nn, Tensor
 from torch.optim.swa_utils import SWALR
 
 import pytorch_lightning as pl
-from lightning_lite.utilities.types import _LRScheduler
+from lightning_lite.utilities.types import LRScheduler
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
 from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
@@ -125,7 +125,7 @@ class StochasticWeightAveraging(Callback):
         self._model_contains_batch_norm: Optional[bool] = None
         self._average_model: Optional["pl.LightningModule"] = None
         self._initialized = False
-        self._swa_scheduler: Optional[_LRScheduler] = None
+        self._swa_scheduler: Optional[LRScheduler] = None
         self._scheduler_state: Optional[Dict] = None
         self._init_n_averaged = 0
         self._latest_update_epoch = -1
@@ -192,7 +192,7 @@ class StochasticWeightAveraging(Callback):
             assert trainer.max_epochs is not None
             self._swa_scheduler = cast(
-                _LRScheduler,
+                LRScheduler,
                 SWALR(
                     optimizer,
                     swa_lr=self._swa_lrs,  # type: ignore[arg-type]
diff --git a/src/pytorch_lightning/core/optimizer.py b/src/pytorch_lightning/core/optimizer.py
index e1a834f8c8..c18c1e2697 100644
--- a/src/pytorch_lightning/core/optimizer.py
+++ b/src/pytorch_lightning/core/optimizer.py
@@ -253,8 +253,8 @@ def _configure_optimizers(
             " Output from `model.configure_optimizers()` should be one of:\n"
             " * `Optimizer`\n"
             " * [`Optimizer`]\n"
-            " * ([`Optimizer`], [`_LRScheduler`])\n"
-            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
+            " * ([`Optimizer`], [`LRScheduler`])\n"
+            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n'
             ' * A list of the previously described dict format, with an optional "frequency" key (int)'
         )
     return optimizers, lr_schedulers, optimizer_frequencies, monitor
diff --git a/src/pytorch_lightning/demos/boring_classes.py b/src/pytorch_lightning/demos/boring_classes.py
index 9967cdddf7..db2c509c98 100644
--- a/src/pytorch_lightning/demos/boring_classes.py
+++ b/src/pytorch_lightning/demos/boring_classes.py
@@ -18,9 +18,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset
 
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
 from pytorch_lightning import LightningDataModule, LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -137,7 +137,7 @@ class BoringModel(LightningModule):
         outputs = cast(List[Dict[str, Tensor]], outputs)
         torch.stack([x["y"] for x in outputs]).mean()
 
-    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
+    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
         lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
         return [optimizer], [lr_scheduler]
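For reference, the new ``_TORCH_LRSCHEDULER`` alias used above simply resolves to whichever base class the installed torch provides. A standalone sketch of the same idea, approximating the ``_TORCH_GREATER_EQUAL_1_14`` flag (which really lives in ``lightning_lite.utilities.imports``) with a plain version comparison:

.. code:: python

    import torch
    from packaging.version import Version

    # Assumption: a bare version check stands in for lightning_lite's import flag.
    _TORCH_GREATER_EQUAL_1_14 = Version(torch.__version__.split("+")[0]) >= Version("1.14.0")

    _TORCH_LRSCHEDULER = (
        torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_1_14 else torch.optim.lr_scheduler._LRScheduler
    )

    # Concrete schedulers subclass whichever base the alias resolves to,
    # so isinstance checks and type hints keep working across torch versions.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    assert isinstance(scheduler, _TORCH_LRSCHEDULER)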
diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py
index 465c65bfa7..5e2dae688f 100644
--- a/src/pytorch_lightning/strategies/deepspeed.py
+++ b/src/pytorch_lightning/strategies/deepspeed.py
@@ -34,7 +34,7 @@ from lightning_lite.plugins import ClusterEnvironment
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.optimizer import _optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
-from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
+from lightning_lite.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
 from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
@@ -426,7 +426,7 @@ class DeepSpeedStrategy(DDPStrategy):
         self,
         model: Module,
         optimizer: Optional[Optimizer],
-        lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
+        lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None,
     ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
diff --git a/src/pytorch_lightning/strategies/hivemind.py b/src/pytorch_lightning/strategies/hivemind.py
index 7cad027ac6..61d39367d0 100644
--- a/src/pytorch_lightning/strategies/hivemind.py
+++ b/src/pytorch_lightning/strategies/hivemind.py
@@ -9,7 +9,7 @@ from torch import Tensor
 
 import pytorch_lightning as pl
 from lightning_lite.utilities.enums import PrecisionType
-from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau
+from lightning_lite.utilities.types import LRScheduler, ReduceLROnPlateau
 from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
 from pytorch_lightning.utilities.data import extract_batch_size
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -312,7 +312,7 @@ class HiveMindScheduler:
 
     base_lrs: List[float]
 
-    def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
+    def __init__(self, optimizer: "hivemind.Optimizer", scheduler: LRScheduler) -> None:
         # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
         # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
         self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}
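The ``HiveMindScheduler`` pattern above (mirror a scheduler's attributes while taking over ``step``) is independent of hivemind itself; a purely illustrative wrapper showing the same ``__dict__`` copy:

.. code:: python

    import torch


    class LoggingSchedulerWrapper:
        """Expose the wrapped scheduler's attributes but intercept `step()` calls."""

        def __init__(self, scheduler) -> None:
            # Copy the instance attributes, skipping anything we want to control ourselves,
            # in the same spirit as `HiveMindScheduler.__init__` above.
            self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}
            self.scheduler = scheduler

        def step(self, epoch=None) -> None:
            print("stepping the wrapped scheduler")
            self.scheduler.step(epoch)


    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    wrapped = LoggingSchedulerWrapper(torch.optim.lr_scheduler.StepLR(optimizer, step_size=1))
    wrapped.step()
    assert wrapped.base_lrs == [0.1]  # copied attribute is visible on the wrapper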
diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 2652267c93..30b2cc5fff 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -21,14 +21,14 @@ from typing import Any, cast, Dict, List, Optional, TYPE_CHECKING, Union
 import numpy as np
 import torch
 from lightning_utilities.core.imports import RequirementCache
-from torch.optim.lr_scheduler import _LRScheduler
 
 import pytorch_lightning as pl
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
+from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -124,7 +124,7 @@ class _LRFinder:
         args = (optimizer, self.lr_max, self.num_training)
         scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
-        scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
+        scheduler = cast(LRScheduler, scheduler)
 
         trainer.strategy.optimizers = [optimizer]
         trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
@@ -404,7 +404,7 @@ class _LRCallback(Callback):
         self.losses.append(smoothed_loss)
 
 
-class _LinearLR(_LRScheduler):
+class _LinearLR(_TORCH_LRSCHEDULER):
     """Linearly increases the learning rate between two boundaries over a number of iterations.
 
     Args:
@@ -423,7 +423,7 @@ class _LinearLR(_LRScheduler):
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
@@ -439,7 +439,7 @@ class _LinearLR(_LRScheduler):
         return self._lr
 
 
-class _ExponentialLR(_LRScheduler):
+class _ExponentialLR(_TORCH_LRSCHEDULER):
     """Exponentially increases the learning rate between two boundaries over a number of iterations.
 
     Arguments:
@@ -458,7 +458,7 @@ class _ExponentialLR(_LRScheduler):
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
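For intuition about the two retyped finder schedulers: they only differ in how they interpolate between the starting LR and ``end_lr``. A quick numeric sketch of the linear and exponential rules with made-up constants:

.. code:: python

    import numpy as np

    base_lr, end_lr, num_iter = 1e-6, 1.0, 100  # illustrative values
    r = np.arange(1, num_iter + 1) / num_iter

    linear_lrs = base_lr + r * (end_lr - base_lr)  # ramp of the "linear" mode
    exponential_lrs = base_lr * (end_lr / base_lr) ** r  # sweep of the "exponential" mode

    assert np.isclose(linear_lrs[-1], end_lr) and np.isclose(exponential_lrs[-1], end_lr)
    print(exponential_lrs[:3])  # geometrically spaced, dense near base_lr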
diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index d766e4fdb7..db736e9cc2 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -27,14 +27,7 @@ from torch.utils.data import DataLoader
 from torchmetrics import Metric
 from typing_extensions import Protocol, runtime_checkable
 
-try:
-    from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
-except ImportError:
-    # For torch <= 1.13.x
-    # TODO: Remove once minimum torch version is 1.14 (or 2.0)
-    from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler
-
-from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER, LRScheduler, ProcessGroup, ReduceLROnPlateau
 
 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, Tensor, _NUMBER]
@@ -118,15 +111,15 @@ class DistributedDataParallel(Protocol):
 
 # todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
-LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
-LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
+LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau]
 
 
 @dataclass
 class LRSchedulerConfig:
-    scheduler: Union[_LRScheduler, ReduceLROnPlateau]
+    scheduler: Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
     # no custom name
     name: Optional[str] = None
     # after epoch is over
diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
index 0fcacf080a..2224ed8569 100644
--- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
+++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
@@ -965,7 +965,7 @@ def test_lr_scheduler_step_not_called(tmpdir):
     with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
         trainer.fit(model)
 
-    # If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`,
+    # If a lr scheduler inherits `torch.optim.lr_scheduler.LRScheduler`,
     # `.step()` is called once during its instantiation.
    # Thus, the call count should be 1, not 0.
    assert lr_step.call_count == 1
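``LRSchedulerConfig`` above is what the trainer builds from the ``lr_scheduler`` entry returned by ``configure_optimizers``. A hedged sketch of the dict form that feeds it (the monitored metric is illustrative and must be something the model actually logs):

.. code:: python

    import torch
    from pytorch_lightning.demos.boring_classes import BoringModel


    class PlateauModel(BoringModel):
        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
            return {
                "optimizer": optimizer,
                # these keys map onto LRSchedulerConfig fields (scheduler, monitor, interval)
                "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss", "interval": "epoch"},
            }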
diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
index 52fb6ba502..ed821b0d6f 100644
--- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py
+++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
@@ -761,7 +761,7 @@ def test_lr_scheduler_step_hook(tmpdir):
         trainer.fit(model)
 
     assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)]
-    # first step is called by PyTorch _LRScheduler
+    # first step is called by PyTorch LRScheduler
     assert mock_method_step.call_count == max_epochs * limit_train_batches + 1
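The premise of the updated comments (one ``step()`` call happens while the scheduler is being constructed) can be checked directly, mirroring the patched test above:

.. code:: python

    import torch
    from unittest.mock import patch

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
        torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # the base class calls `step()` once during __init__
    assert lr_step.call_count == 1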