Use LRScheduler for torch >= 1.14 otherwise use _LRScheduler (#15768)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Author: Quazi Marufur Rahman (2022-12-12 23:33:26 +10:00, committed by GitHub)
parent 4fea6bf43e
commit 2577285dd5
11 changed files with 34 additions and 36 deletions
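In short, the commit replaces the old try/except import of the scheduler base class with a version-conditional alias. A minimal, self-contained sketch of the pattern (the `packaging` version check is an assumption made for this sketch; the repository itself uses its `_TORCH_GREATER_EQUAL_1_14` flag from `lightning_lite.utilities.imports`):

```python
# Sketch of the version-conditional alias introduced by this commit.
# Assumption: `packaging` is installed; Lightning uses _TORCH_GREATER_EQUAL_1_14 instead.
import torch
from packaging.version import Version

if Version(torch.__version__.split("+")[0]) >= Version("1.14.0"):
    # torch >= 1.14 exposes the public base class ...
    from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
else:
    # ... older releases only ship the private name
    from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler

# Either way, stock torch schedulers are instances of the alias.
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
assert isinstance(torch.optim.lr_scheduler.StepLR(optimizer, step_size=1), TorchLRScheduler)
```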

View File

@@ -201,7 +201,7 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t
python main.py fit --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=epoch
-Furthermore, any custom subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler:
+Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:
.. code:: python

View File

@@ -19,7 +19,7 @@ from torch import Tensor
from torch.optim import Optimizer
from typing_extensions import Protocol, runtime_checkable
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14
_PATH = Union[str, Path]
_DEVICE = Union[torch.device, str, int]
@@ -63,7 +63,7 @@ class CollectibleGroup(Protocol):
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
-class _LRScheduler(_Stateful[str], Protocol):
+class LRScheduler(_Stateful[str], Protocol):
optimizer: Optimizer
base_lrs: List[float]
@@ -74,6 +74,11 @@ class _LRScheduler(_Stateful[str], Protocol):
...
+_TORCH_LRSCHEDULER = (
+    torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_1_14 else torch.optim.lr_scheduler._LRScheduler
+)
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
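Two things come out of this file: the public `LRScheduler` protocol (renamed from `_LRScheduler`) and the `_TORCH_LRSCHEDULER` alias for the concrete torch base class. Because the protocol stays `@runtime_checkable`, duck-typed `isinstance` checks keep working against real torch schedulers. A small sketch, assuming an environment that already contains this change:

```python
# Sketch: the renamed protocol still supports runtime isinstance checks, and the
# alias resolves to whichever torch base class matches the installed torch version.
import torch
from lightning_lite.utilities.types import _TORCH_LRSCHEDULER, LRScheduler

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

assert isinstance(scheduler, LRScheduler)         # structural (Protocol) check
assert isinstance(scheduler, _TORCH_LRSCHEDULER)  # nominal (base-class) check
```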

View File

@@ -23,7 +23,7 @@ from torch import nn, Tensor
from torch.optim.swa_utils import SWALR
import pytorch_lightning as pl
-from lightning_lite.utilities.types import _LRScheduler
+from lightning_lite.utilities.types import LRScheduler
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
@@ -125,7 +125,7 @@ class StochasticWeightAveraging(Callback):
self._model_contains_batch_norm: Optional[bool] = None
self._average_model: Optional["pl.LightningModule"] = None
self._initialized = False
-self._swa_scheduler: Optional[_LRScheduler] = None
+self._swa_scheduler: Optional[LRScheduler] = None
self._scheduler_state: Optional[Dict] = None
self._init_n_averaged = 0
self._latest_update_epoch = -1
@@ -192,7 +192,7 @@ class StochasticWeightAveraging(Callback):
assert trainer.max_epochs is not None
self._swa_scheduler = cast(
-_LRScheduler,
+LRScheduler,
SWALR(
optimizer,
swa_lr=self._swa_lrs, # type: ignore[arg-type]

View File

@@ -253,8 +253,8 @@ def _configure_optimizers(
" Output from `model.configure_optimizers()` should be one of:\n"
" * `Optimizer`\n"
" * [`Optimizer`]\n"
" * ([`Optimizer`], [`_LRScheduler`])\n"
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
" * ([`Optimizer`], [`LRScheduler`])\n"
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)
return optimizers, lr_schedulers, optimizer_frequencies, monitor
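For reference, a minimal sketch of the dict return shape listed in that error message, inside an ordinary `LightningModule` (the `val_loss` metric name and the `LitModel` class are assumptions made for the example):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    # ... training_step / validation_step omitted ...

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
        # Dict form from the message above; "monitor" is needed because
        # ReduceLROnPlateau steps on a logged metric rather than on a fixed schedule.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }
```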

View File

@@ -18,9 +18,9 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -137,7 +137,7 @@ class BoringModel(LightningModule):
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["y"] for x in outputs]).mean()
-def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
+def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]

View File

@@ -34,7 +34,7 @@ from lightning_lite.plugins import ClusterEnvironment
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.optimizer import _optimizers_to_device
from lightning_lite.utilities.seed import reset_seed
-from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
+from lightning_lite.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
@@ -426,7 +426,7 @@ class DeepSpeedStrategy(DDPStrategy):
self,
model: Module,
optimizer: Optional[Optimizer],
-lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
+lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None,
) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.

View File

@@ -9,7 +9,7 @@ from torch import Tensor
import pytorch_lightning as pl
from lightning_lite.utilities.enums import PrecisionType
-from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau
+from lightning_lite.utilities.types import LRScheduler, ReduceLROnPlateau
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -312,7 +312,7 @@ class HiveMindScheduler:
base_lrs: List[float]
def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
def __init__(self, optimizer: "hivemind.Optimizer", scheduler: LRScheduler) -> None:
# copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
# implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}
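As a side note, the attribute-copy trick described in the comment above is a generic wrapping pattern: mirror the wrapped scheduler's instance state onto the wrapper while keeping control of `step`. A hedged, stand-alone sketch (`SchedulerProxy` is illustrative, not the real `HiveMindScheduler`):

```python
# Illustrative proxy that mirrors a scheduler's instance state but owns `step`.
import torch


class SchedulerProxy:
    def __init__(self, scheduler) -> None:
        # Copy instance attributes, skipping anything we want to control ourselves
        # (the real code also skips `__del__` to avoid inherited cleanup logic).
        self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}
        self.wrapped = scheduler

    def step(self, epoch=None) -> None:
        # Decide when (or whether) to forward the step to the wrapped scheduler.
        self.wrapped.step(epoch)


optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
proxy = SchedulerProxy(torch.optim.lr_scheduler.StepLR(optimizer, step_size=1))
proxy.step()  # forwards to the real scheduler
```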

View File

@@ -21,14 +21,14 @@ from typing import Any, cast, Dict, List, Optional, TYPE_CHECKING, Union
import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache
-from torch.optim.lr_scheduler import _LRScheduler
import pytorch_lightning as pl
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
+from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT
# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
@@ -124,7 +124,7 @@ class _LRFinder:
args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
-scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
+scheduler = cast(LRScheduler, scheduler)
trainer.strategy.optimizers = [optimizer]
trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
@@ -404,7 +404,7 @@ class _LRCallback(Callback):
self.losses.append(smoothed_loss)
-class _LinearLR(_LRScheduler):
+class _LinearLR(_TORCH_LRSCHEDULER):
"""Linearly increases the learning rate between two boundaries over a number of iterations.
Args:
@@ -423,7 +423,7 @@ class _LinearLR(_LRScheduler):
self.num_iter = num_iter
super().__init__(optimizer, last_epoch)
-def get_lr(self) -> List[float]: # type: ignore[override]
+def get_lr(self) -> List[float]:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
@@ -439,7 +439,7 @@ class _LinearLR(_LRScheduler):
return self._lr
-class _ExponentialLR(_LRScheduler):
+class _ExponentialLR(_TORCH_LRSCHEDULER):
"""Exponentially increases the learning rate between two boundaries over a number of iterations.
Arguments:
@@ -458,7 +458,7 @@ class _ExponentialLR(_LRScheduler):
self.num_iter = num_iter
super().__init__(optimizer, last_epoch)
-def get_lr(self) -> List[float]: # type: ignore[override]
+def get_lr(self) -> List[float]:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
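The same base-class alias can back user-defined schedulers, in the same way `_LinearLR` and `_ExponentialLR` above subclass `_TORCH_LRSCHEDULER`. A hedged sketch with a made-up scheduler (`ConstantFactorLR` is illustrative, not part of the codebase):

```python
from typing import List

import torch
from lightning_lite.utilities.types import _TORCH_LRSCHEDULER


class ConstantFactorLR(_TORCH_LRSCHEDULER):  # hypothetical example scheduler
    def __init__(self, optimizer: torch.optim.Optimizer, factor: float = 0.5, last_epoch: int = -1) -> None:
        self.factor = factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        # Scale every base learning rate by a fixed factor.
        return [base_lr * self.factor for base_lr in self.base_lrs]


optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = ConstantFactorLR(optimizer)  # the base __init__ already performs one step()
```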

View File

@@ -27,14 +27,7 @@ from torch.utils.data import DataLoader
from torchmetrics import Metric
from typing_extensions import Protocol, runtime_checkable
-try:
-    from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
-except ImportError:
-    # For torch <= 1.13.x
-    # TODO: Remove once minimum torch version is 1.14 (or 2.0)
-    from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler
-from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER, LRScheduler, ProcessGroup, ReduceLROnPlateau
_NUMBER = Union[int, float]
_METRIC = Union[Metric, Tensor, _NUMBER]
@@ -118,15 +111,15 @@ class DistributedDataParallel(Protocol):
# todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
-LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
-LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
+LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau]
@dataclass
class LRSchedulerConfig:
-scheduler: Union[_LRScheduler, ReduceLROnPlateau]
+scheduler: Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
# no custom name
name: Optional[str] = None
# after epoch is over

View File

@@ -965,7 +965,7 @@ def test_lr_scheduler_step_not_called(tmpdir):
with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
trainer.fit(model)
-# If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`,
+# If a lr scheduler inherits `torch.optim.lr_scheduler.LRScheduler`,
# `.step()` is called once during its instantiation.
# Thus, the call count should be 1, not 0.
assert lr_step.call_count == 1
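For context, that single call comes from the scheduler's constructor: the torch base class performs one `step()` at the end of `__init__`. A hedged, trainer-free sketch of the same check:

```python
# The base class's __init__ performs an initial step(); nothing else steps here.
from unittest.mock import patch

import torch

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
    torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

assert lr_step.call_count == 1
```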

View File

@@ -761,7 +761,7 @@ def test_lr_scheduler_step_hook(tmpdir):
trainer.fit(model)
assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)]
-# first step is called by PyTorch _LRScheduler
+# first step is called by PyTorch LRScheduler
assert mock_method_step.call_count == max_epochs * limit_train_batches + 1