Use LRScheduler for torch >= 1.14 otherwise use _LRScheduler (#15768)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

parent 4fea6bf43e
commit 2577285dd5
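
The whole change hinges on one version-gated alias: a single concrete base class, resolved once, for every place that needs to subclass or isinstance-check a real torch scheduler, while the typing-only Protocol is renamed from `_LRScheduler` to `LRScheduler`. Below is a minimal sketch of that pattern; computing the flag with `packaging` is an assumption for this note, the actual code imports `_TORCH_GREATER_EQUAL_1_14` from `lightning_lite.utilities.imports`.

import torch
from packaging.version import Version

# Assumption: derive the flag locally; the diff imports it instead.
_TORCH_GREATER_EQUAL_1_14 = Version(torch.__version__).release >= (1, 14)

# torch >= 1.14 exposes the public `LRScheduler`; older releases only ship the
# private `_LRScheduler`. Only the selected branch is evaluated, so no missing
# attribute is ever touched on the installed torch.
_TORCH_LRSCHEDULER = (
    torch.optim.lr_scheduler.LRScheduler
    if _TORCH_GREATER_EQUAL_1_14
    else torch.optim.lr_scheduler._LRScheduler
)

print(_TORCH_LRSCHEDULER)  # the concrete scheduler base class for the installed torch
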
@@ -201,7 +201,7 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t
 
     python main.py fit --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=epoch
 
-Furthermore, any custom subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler:
+Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:
 
 .. code:: python
 
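
A hedged illustration of the kind of custom subclass the updated sentence refers to; the class, its name, and the warmup logic are invented for this note, not taken from the docs.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LRScheduler  # `_LRScheduler` on torch < 1.14


class ConstantWarmup(LRScheduler):
    """Scale the configured learning rate by 0.1 for the first `warmup_steps` steps."""

    def __init__(self, optimizer, warmup_steps: int = 5, last_epoch: int = -1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        factor = 0.1 if self.last_epoch < self.warmup_steps else 1.0
        return [base_lr * factor for base_lr in self.base_lrs]


optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = ConstantWarmup(optimizer, warmup_steps=2)  # usable like any other torch scheduler
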
@@ -19,7 +19,7 @@ from torch import Tensor
 from torch.optim import Optimizer
 from typing_extensions import Protocol, runtime_checkable
 
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14
 
 _PATH = Union[str, Path]
 _DEVICE = Union[torch.device, str, int]
@@ -63,7 +63,7 @@ class CollectibleGroup(Protocol):
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class _LRScheduler(_Stateful[str], Protocol):
+class LRScheduler(_Stateful[str], Protocol):
     optimizer: Optimizer
     base_lrs: List[float]
 
@@ -74,6 +74,11 @@ class _LRScheduler(_Stateful[str], Protocol):
         ...
 
 
+_TORCH_LRSCHEDULER = (
+    torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_1_14 else torch.optim.lr_scheduler._LRScheduler
+)
+
+
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
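
A hypothetical usage note for the two names this file now exposes: `LRScheduler` is a structural, `runtime_checkable` Protocol meant for annotations and isinstance checks, while `_TORCH_LRSCHEDULER` is the concrete torch base class to inherit from. Neither snippet below appears in the diff itself.

import torch
from lightning_lite.utilities.types import _TORCH_LRSCHEDULER, LRScheduler

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# The Protocol is runtime-checkable: any object exposing the declared members
# (optimizer, base_lrs, the _Stateful methods, ...) passes the check.
assert isinstance(step_lr, LRScheduler)


# Subclassing goes through the concrete alias, which resolves to the public
# `LRScheduler` on torch >= 1.14 and to `_LRScheduler` before that.
class _NoOpLR(_TORCH_LRSCHEDULER):
    def get_lr(self):
        return [group["lr"] for group in self.optimizer.param_groups]
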
@@ -23,7 +23,7 @@ from torch import nn, Tensor
 from torch.optim.swa_utils import SWALR
 
 import pytorch_lightning as pl
-from lightning_lite.utilities.types import _LRScheduler
+from lightning_lite.utilities.types import LRScheduler
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
 from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
@@ -125,7 +125,7 @@ class StochasticWeightAveraging(Callback):
         self._model_contains_batch_norm: Optional[bool] = None
         self._average_model: Optional["pl.LightningModule"] = None
         self._initialized = False
-        self._swa_scheduler: Optional[_LRScheduler] = None
+        self._swa_scheduler: Optional[LRScheduler] = None
         self._scheduler_state: Optional[Dict] = None
         self._init_n_averaged = 0
         self._latest_update_epoch = -1
@@ -192,7 +192,7 @@ class StochasticWeightAveraging(Callback):
 
             assert trainer.max_epochs is not None
             self._swa_scheduler = cast(
-                _LRScheduler,
+                LRScheduler,
                 SWALR(
                     optimizer,
                     swa_lr=self._swa_lrs,  # type: ignore[arg-type]

@@ -253,8 +253,8 @@ def _configure_optimizers(
             " Output from `model.configure_optimizers()` should be one of:\n"
             " * `Optimizer`\n"
             " * [`Optimizer`]\n"
-            " * ([`Optimizer`], [`_LRScheduler`])\n"
-            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
+            " * ([`Optimizer`], [`LRScheduler`])\n"
+            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n'
             ' * A list of the previously described dict format, with an optional "frequency" key (int)'
         )
     return optimizers, lr_schedulers, optimizer_frequencies, monitor
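
For reference, a hedged example of one of the accepted return formats listed in that error message, written against the public `configure_optimizers` contract; the module and hyperparameters are placeholders.

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # placeholder module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # The dict format from the message above; any `LRScheduler` (or a
        # `ReduceLROnPlateau` together with a "monitor" key) is accepted.
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
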
@@ -18,9 +18,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset
 
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
 from pytorch_lightning import LightningDataModule, LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -137,7 +137,7 @@ class BoringModel(LightningModule):
         outputs = cast(List[Dict[str, Tensor]], outputs)
         torch.stack([x["y"] for x in outputs]).mean()
 
-    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
+    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
         lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
         return [optimizer], [lr_scheduler]

@@ -34,7 +34,7 @@ from lightning_lite.plugins import ClusterEnvironment
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.optimizer import _optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
-from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
+from lightning_lite.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
 from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
@@ -426,7 +426,7 @@ class DeepSpeedStrategy(DDPStrategy):
         self,
         model: Module,
         optimizer: Optional[Optimizer],
-        lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
+        lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None,
     ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 
@@ -9,7 +9,7 @@ from torch import Tensor
 
 import pytorch_lightning as pl
 from lightning_lite.utilities.enums import PrecisionType
-from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau
+from lightning_lite.utilities.types import LRScheduler, ReduceLROnPlateau
 from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
 from pytorch_lightning.utilities.data import extract_batch_size
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -312,7 +312,7 @@ class HiveMindScheduler:
 
     base_lrs: List[float]
 
-    def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
+    def __init__(self, optimizer: "hivemind.Optimizer", scheduler: LRScheduler) -> None:
         # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
         # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
        self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}

@@ -21,14 +21,14 @@ from typing import Any, cast, Dict, List, Optional, TYPE_CHECKING, Union
 import numpy as np
 import torch
 from lightning_utilities.core.imports import RequirementCache
-from torch.optim.lr_scheduler import _LRScheduler
 
 import pytorch_lightning as pl
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
+from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -124,7 +124,7 @@ class _LRFinder:
 
         args = (optimizer, self.lr_max, self.num_training)
         scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
-        scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
+        scheduler = cast(LRScheduler, scheduler)
 
         trainer.strategy.optimizers = [optimizer]
         trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
@@ -404,7 +404,7 @@ class _LRCallback(Callback):
             self.losses.append(smoothed_loss)
 
 
-class _LinearLR(_LRScheduler):
+class _LinearLR(_TORCH_LRSCHEDULER):
     """Linearly increases the learning rate between two boundaries over a number of iterations.
 
     Args:
@@ -423,7 +423,7 @@ class _LinearLR(_LRScheduler):
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
@@ -439,7 +439,7 @@ class _LinearLR(_LRScheduler):
         return self._lr
 
 
-class _ExponentialLR(_LRScheduler):
+class _ExponentialLR(_TORCH_LRSCHEDULER):
     """Exponentially increases the learning rate between two boundaries over a number of iterations.
 
     Arguments:
@@ -458,7 +458,7 @@ class _ExponentialLR(_LRScheduler):
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
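
Dropping the `# type: ignore[override]` presumably works because the public base class ships a `get_lr` stub returning `List[float]`, so the override now matches. Below is a hypothetical scheduler in the same style as `_LinearLR` above; the `end_lr` parameter and the interpolation line are assumptions for illustration, not the tuner's code.

from typing import List

from lightning_lite.utilities.types import _TORCH_LRSCHEDULER


class LinearRampLR(_TORCH_LRSCHEDULER):
    """Ramp each parameter group linearly from its base LR to `end_lr` over `num_iter` steps."""

    def __init__(self, optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
        self.end_lr = end_lr  # assumed parameter, not from the diff
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.num_iter
        # assumed interpolation; only `curr_iter` and `r` appear in the diff
        return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]
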
@@ -27,14 +27,7 @@ from torch.utils.data import DataLoader
 from torchmetrics import Metric
 from typing_extensions import Protocol, runtime_checkable
 
-try:
-    from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
-except ImportError:
-    # For torch <= 1.13.x
-    # TODO: Remove once minimum torch version is 1.14 (or 2.0)
-    from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler
-
-from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER, LRScheduler, ProcessGroup, ReduceLROnPlateau
 
 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, Tensor, _NUMBER]
@@ -118,15 +111,15 @@ class DistributedDataParallel(Protocol):
 
 
 # todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
-LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
-LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
+LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau]
 
 
 @dataclass
 class LRSchedulerConfig:
-    scheduler: Union[_LRScheduler, ReduceLROnPlateau]
+    scheduler: Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
     # no custom name
     name: Optional[str] = None
     # after epoch is over
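
The tuple/union pair kept above reflects a practical split: isinstance needs a runtime tuple of classes, while annotations need a `Union`. A brief hypothetical check using the re-exported aliases:

import torch
from pytorch_lightning.utilities.types import LRSchedulerTypeTuple

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# True: `ReduceLROnPlateau` is listed explicitly as the tuple's second entry,
# so the check does not depend on which torch base class it inherits from.
assert isinstance(plateau, LRSchedulerTypeTuple)
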
@@ -965,7 +965,7 @@ def test_lr_scheduler_step_not_called(tmpdir):
     with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
         trainer.fit(model)
 
-    # If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`,
+    # If a lr scheduler inherits `torch.optim.lr_scheduler.LRScheduler`,
     # `.step()` is called once during its instantiation.
     # Thus, the call count should be 1, not 0.
     assert lr_step.call_count == 1
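
The updated comment states that constructing an `LRScheduler` subclass already triggers one `.step()` call; this can be checked in isolation, outside the Trainer. A small standalone sketch mirroring the patch used in the test above:

from unittest.mock import patch

import torch

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
    torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# One call comes from instantiation alone, before any optimizer step has run.
assert lr_step.call_count == 1
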
@@ -761,7 +761,7 @@ def test_lr_scheduler_step_hook(tmpdir):
         trainer.fit(model)
 
     assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)]
-    # first step is called by PyTorch _LRScheduler
+    # first step is called by PyTorch LRScheduler
     assert mock_method_step.call_count == max_epochs * limit_train_batches + 1