Deprecate `TrainerOptimizersMixin` and move functionality to `core/optimizer.py` (#11155)

Danielle Pintz 2021-12-22 17:56:37 -08:00 committed by GitHub
parent 81301dbba7
commit a6a28e08d2
13 changed files with 321 additions and 271 deletions

View File

@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068))
- Deprecated `TrainerOptimizersMixin` and moved functionality to `core/optimizer.py` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))
- Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148))
### Removed
@ -351,6 +354,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
- Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))
## [1.5.7] - 2021-12-21

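For code that still calls the mixin methods directly, the migration implied by this commit is roughly the following. This is a minimal sketch that mirrors the wiring used by the new tests added later in this commit; `TinyModel` is a hypothetical stand-in for the `BoringModel` test helper, and the imported helpers are private, underscore-prefixed APIs that may change.

```python
import torch
from torch import nn

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import (
    _convert_to_lightning_optimizers,
    _init_optimizers_and_lr_schedulers,
)


class TinyModel(pl.LightningModule):
    """Hypothetical stand-in for the BoringModel test helper."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer()
model = TinyModel()
# same wiring the new deprecation/optimizer tests in this commit use
trainer.strategy.connect(model)
trainer.lightning_module.trainer = trainer

# before: opt, sched, freq = trainer.init_optimizers(model); trainer.convert_to_lightning_optimizers()
optimizers, lr_schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(model)
trainer.optimizers = optimizers
_convert_to_lightning_optimizers(trainer)
```

The old `trainer.init_optimizers(model)` and `trainer.convert_to_lightning_optimizers()` entry points keep working until v1.8 but emit deprecation warnings, as the deprecation test added later in this commit exercises.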
View File

@ -24,7 +24,7 @@ from torch.optim.swa_utils import SWALR
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.core.optimizer import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

View File

@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import weakref
from contextlib import contextmanager
from typing import Any, Callable, Generator, Optional
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from weakref import proxy
import torch
from torch import optim
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -54,7 +57,8 @@ class LightningOptimizer:
return self._optimizer
def _on_trainer_init(self, trainer: "pl.Trainer") -> None:
self._trainer = proxy(trainer)
# check if trainer is already of type weakproxy since we can't call proxy on a weakproxy
self._trainer = trainer if isinstance(trainer, weakref.ProxyType) else proxy(trainer)
for opt_idx, opt in enumerate(trainer.optimizers):
if opt == self._optimizer:
self._optimizer_idx = opt_idx
@ -162,3 +166,227 @@ class LightningOptimizer:
assert trainer is not None
with trainer.profiler.profile(profiler_action):
trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
model.trainer._lightning_optimizers = None
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
if optim_conf is None:
rank_zero_warn(
"`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
)
optim_conf = _MockOptimizer()
optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
lr_schedulers = _configure_schedulers(lr_schedulers, monitor, not model.automatic_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)
return optimizers, lr_schedulers, optimizer_frequencies
def _configure_optimizers(
optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
) -> Tuple[List, List, List, Optional[str]]:
optimizers, lr_schedulers, optimizer_frequencies = [], [], []
monitor = None
# single output, single optimizer
if isinstance(optim_conf, Optimizer):
optimizers = [optim_conf]
# two lists, optimizer + lr schedulers
elif (
isinstance(optim_conf, (list, tuple))
and len(optim_conf) == 2
and isinstance(optim_conf[0], list)
and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
):
opt, sch = optim_conf
optimizers = opt
lr_schedulers = sch if isinstance(sch, list) else [sch]
# single dictionary
elif isinstance(optim_conf, dict):
_validate_optim_conf(optim_conf)
optimizers = [optim_conf["optimizer"]]
monitor = optim_conf.get("monitor", None)
lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
for opt_dict in optim_conf:
_validate_optim_conf(opt_dict)
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
scheduler_dict = (
lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
if isinstance(scheduler, dict)
else {"scheduler": scheduler, "opt_idx": opt_idx}
)
lr_schedulers = [
scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
for opt_idx, opt_dict in enumerate(optim_conf)
if "lr_scheduler" in opt_dict
]
optimizer_frequencies = [
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
]
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
optimizers = list(optim_conf)
# unknown configuration
else:
raise MisconfigurationException(
"Unknown configuration for model optimizers."
" Output from `model.configure_optimizers()` should be one of:\n"
" * `Optimizer`\n"
" * [`Optimizer`]\n"
" * ([`Optimizer`], [`_LRScheduler`])\n"
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)
return optimizers, lr_schedulers, optimizer_frequencies, monitor
def _configure_schedulers(
schedulers: list, monitor: Optional[str], is_manual_optimization: bool
) -> List[Dict[str, Any]]:
"""Convert each scheduler into dict structure with relevant information."""
lr_schedulers = []
default_config = _get_default_scheduler_config()
# TODO: move is_manual_optimization check out of for loop
for scheduler in schedulers:
if is_manual_optimization:
if isinstance(scheduler, dict):
invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
if keys_to_warn:
rank_zero_warn(
f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
" You need to call `lr_scheduler.step()` manually in manual optimization.",
category=RuntimeWarning,
)
scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
lr_schedulers.append({**default_config, **scheduler})
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
if isinstance(scheduler, dict):
# check provided keys
extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
)
if "scheduler" not in scheduler:
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
)
if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
raise MisconfigurationException(
'The "interval" key in lr scheduler dict must be "step" or "epoch"'
f' but is "{scheduler["interval"]}"'
)
scheduler["reduce_on_plateau"] = isinstance(
scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
)
if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
raise MisconfigurationException(
"The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
' For example: {"optimizer": optimizer, "lr_scheduler":'
' {"scheduler": scheduler, "monitor": "your_loss"}}'
)
is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
rank_zero_warn(
"A `OneCycleLR` scheduler is using 'interval': 'epoch'."
" Are you sure you didn't mean 'interval': 'step'?",
category=RuntimeWarning,
)
lr_schedulers.append({**default_config, **scheduler})
elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
if monitor is None:
raise MisconfigurationException(
"`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
" scheduler is used. For example:"
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
return lr_schedulers
def _get_default_scheduler_config() -> Dict[str, Any]:
return {
"scheduler": None,
"name": None, # no custom name
"interval": "epoch", # after epoch is over
"frequency": 1, # every epoch/batch
"reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler
"monitor": None, # value to monitor for ReduceLROnPlateau
"strict": True, # enforce that the monitor exists for ReduceLROnPlateau
"opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified
}
def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
raise MisconfigurationException(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
)
def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
extra_keys = optim_conf.keys() - valid_keys
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
)
def _convert_to_lightning_optimizers(trainer: "pl.Trainer") -> None:
def _convert_to_lightning_optimizer(optimizer: Optimizer) -> LightningOptimizer:
if not isinstance(optimizer, LightningOptimizer):
optimizer = LightningOptimizer(optimizer) # type: ignore [assignment]
optimizer._on_trainer_init(trainer)
return optimizer # type: ignore [return-value]
trainer._lightning_optimizers = { # type: ignore [assignment]
opt_idx: _convert_to_lightning_optimizer(opt) for opt_idx, opt in enumerate(trainer.optimizers)
}
class _MockOptimizer(Optimizer):
"""The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from
`configure_optimizers`."""
def __init__(self) -> None:
super().__init__([torch.zeros(1)], {})
def add_param_group(self, param_group: Dict[Any, Any]) -> None:
pass # Do Nothing
def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
pass # Do Nothing
def state_dict(self) -> Dict[str, Any]:
return {} # Return Empty
def step(self, closure: Callable = None) -> None:
if closure is not None:
closure()
def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
pass # Do Nothing
def __repr__(self) -> str:
return "No Optimizer"

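Taken together, the branches in `_configure_optimizers` above accept exactly the shapes its error message lists. A hedged, self-contained illustration of the parsed output for each shape (the expected tuples are inferred from the code above; these are private helpers):

```python
import torch
from torch import nn, optim

from pytorch_lightning.core.optimizer import _configure_optimizers

params = [nn.Parameter(torch.zeros(1))]
opt_a, opt_b = optim.SGD(params, lr=0.1), optim.Adam(params, lr=0.01)
sched = optim.lr_scheduler.StepLR(opt_a, step_size=1)

# single optimizer -> ([opt], [], [], None)
opts, scheds, freqs, monitor = _configure_optimizers(opt_a)
assert opts == [opt_a] and scheds == [] and freqs == [] and monitor is None

# two lists: optimizers + lr schedulers
opts, scheds, freqs, monitor = _configure_optimizers(([opt_a], [sched]))
assert scheds == [sched]

# single dict with optional "lr_scheduler" and "monitor" keys
opts, scheds, freqs, monitor = _configure_optimizers(
    {"optimizer": opt_a, "lr_scheduler": sched, "monitor": "val_loss"}
)
assert monitor == "val_loss"

# list of dicts; a "frequency" must then be given for every optimizer
opts, scheds, freqs, monitor = _configure_optimizers(
    [
        {"optimizer": opt_a, "lr_scheduler": sched, "frequency": 1},
        {"optimizer": opt_b, "frequency": 5},
    ]
)
assert freqs == [1, 5] and scheds[0]["opt_idx"] == 0
```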
View File

@ -31,7 +31,7 @@ from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@ -347,7 +347,7 @@ class DDPStrategy(ParallelStrategy):
del optimizer
trainer = self.lightning_module.trainer
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()
_convert_to_lightning_optimizers(trainer)
def configure_ddp(self) -> None:
self.pre_configure_ddp()

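The two-line pattern above, also used by the sharded strategy later in this commit, replaces the removed `trainer.convert_to_lightning_optimizers()` call. A hedged sketch of it in isolation; `replace_optimizers` is a hypothetical name, not part of the change:

```python
from typing import List

from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers


def replace_optimizers(trainer: "pl.Trainer", optimizers: List[Optimizer]) -> None:
    # mirrors the DDP/sharded strategy changes in this commit: swap the optimizers
    # on the trainer, then rebuild the LightningOptimizer wrappers
    trainer.optimizers = optimizers
    _convert_to_lightning_optimizers(trainer)
```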
View File

@ -27,11 +27,11 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
@ -446,9 +446,7 @@ class DeepSpeedStrategy(DDPStrategy):
self._initialize_deepspeed_inference(model)
def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
self.lightning_module
)
optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
if len(optimizers) > 1 or len(schedulers) > 1:
raise MisconfigurationException(
"DeepSpeed currently only supports single optimizer, single optional scheduler."

View File

@ -19,7 +19,7 @@ from torch.nn import Module
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
@ -50,7 +50,7 @@ class DDPShardedStrategy(DDPStrategy):
optimizers=trainer.optimizers,
)
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()
_convert_to_lightning_optimizers(trainer)
def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.

View File

@ -22,6 +22,7 @@ from torch.optim import Optimizer
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@ -377,7 +378,7 @@ class Strategy(ABC):
return dataloader
def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
return trainer.init_optimizers(model)
return _init_optimizers_and_lr_schedulers(model)
@property
def restore_checkpoint_after_setup(self) -> bool:

View File

@ -13,239 +13,43 @@
# limitations under the License.
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import optim
from torch.optim.optimizer import Optimizer
from typing import List, Optional, Tuple
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import LRSchedulerConfig
from pytorch_lightning.core.optimizer import (
_convert_to_lightning_optimizers,
_init_optimizers_and_lr_schedulers,
LightningOptimizer,
)
from pytorch_lightning.utilities import rank_zero_deprecation
class TrainerOptimizersMixin(ABC):
r"""
.. deprecated:: v1.6
The `TrainerOptimizersMixin` was deprecated in v1.6 and will be removed in v1.8.
"""
_lightning_optimizers: Optional[List[LightningOptimizer]]
def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
r"""
.. deprecated:: v1.6
`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8.
"""
rank_zero_deprecation(
"`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8."
)
pl_module = self.lightning_module or model
self._lightning_optimizers = None
optim_conf = self._call_lightning_module_hook("configure_optimizers", pl_module=pl_module)
if optim_conf is None:
rank_zero_warn(
"`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
)
optim_conf = _MockOptimizer()
optimizers, lr_schedulers, optimizer_frequencies, monitor = self._configure_optimizers(optim_conf)
lr_schedulers = self._configure_schedulers(lr_schedulers, monitor, not pl_module.automatic_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)
return optimizers, lr_schedulers, optimizer_frequencies
@staticmethod
def _configure_optimizers(
optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
) -> Tuple[List, List, List, Optional[str]]:
optimizers, lr_schedulers, optimizer_frequencies = [], [], []
monitor = None
# single output, single optimizer
if isinstance(optim_conf, Optimizer):
optimizers = [optim_conf]
# two lists, optimizer + lr schedulers
elif (
isinstance(optim_conf, (list, tuple))
and len(optim_conf) == 2
and isinstance(optim_conf[0], list)
and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
):
opt, sch = optim_conf
optimizers = opt
lr_schedulers = sch if isinstance(sch, list) else [sch]
# single dictionary
elif isinstance(optim_conf, dict):
_validate_optim_conf(optim_conf)
optimizers = [optim_conf["optimizer"]]
monitor = optim_conf.get("monitor", None)
lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
for opt_dict in optim_conf:
_validate_optim_conf(opt_dict)
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
scheduler_dict = (
lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
if isinstance(scheduler, dict)
else {"scheduler": scheduler, "opt_idx": opt_idx}
)
lr_schedulers = [
scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
for opt_idx, opt_dict in enumerate(optim_conf)
if "lr_scheduler" in opt_dict
]
optimizer_frequencies = [
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
]
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
optimizers = list(optim_conf)
# unknown configuration
else:
raise MisconfigurationException(
"Unknown configuration for model optimizers."
" Output from `model.configure_optimizers()` should either be:\n"
" * `torch.optim.Optimizer`\n"
" * [`torch.optim.Optimizer`]\n"
" * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n"
' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)
return optimizers, lr_schedulers, optimizer_frequencies, monitor
return _init_optimizers_and_lr_schedulers(pl_module)
def convert_to_lightning_optimizers(self):
def _convert_to_lightning_optimizer(trainer, optimizer):
if not isinstance(optimizer, LightningOptimizer):
optimizer = LightningOptimizer(optimizer)
optimizer._on_trainer_init(trainer)
return optimizer
self._lightning_optimizers = {
opt_idx: _convert_to_lightning_optimizer(self, opt) for opt_idx, opt in enumerate(self.optimizers)
}
@staticmethod
def _configure_schedulers(
schedulers: list, monitor: Optional[str], is_manual_optimization: bool
) -> List[LRSchedulerConfig]:
"""Convert each scheduler into dict structure with relevant information."""
lr_schedulers = []
default_config = _get_default_scheduler_config()
for scheduler in schedulers:
if is_manual_optimization:
if isinstance(scheduler, dict):
invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
if keys_to_warn:
rank_zero_warn(
f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
" You need to call `lr_scheduler.step()` manually in manual optimization.",
category=RuntimeWarning,
)
scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
lr_schedulers.append({**default_config, **scheduler})
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
if isinstance(scheduler, dict):
# check provided keys
extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
)
if "scheduler" not in scheduler:
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
)
if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
raise MisconfigurationException(
'The "interval" key in lr scheduler dict must be "step" or "epoch"'
f' but is "{scheduler["interval"]}"'
)
scheduler["reduce_on_plateau"] = isinstance(
scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
)
if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
raise MisconfigurationException(
"The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
' For example: {"optimizer": optimizer, "lr_scheduler":'
' {"scheduler": scheduler, "monitor": "your_loss"}}'
)
is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
rank_zero_warn(
"A `OneCycleLR` scheduler is using 'interval': 'epoch'."
" Are you sure you didn't mean 'interval': 'step'?",
category=RuntimeWarning,
)
lr_schedulers.append({**default_config, **scheduler})
elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
if monitor is None:
raise MisconfigurationException(
"`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
" scheduler is used. For example:"
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
return lr_schedulers
class _MockOptimizer(Optimizer):
"""The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from
`configure_optimizers`."""
def __init__(self):
super().__init__([torch.zeros(1)], {})
def add_param_group(self, param_group):
pass # Do Nothing
def load_state_dict(self, state_dict):
pass # Do Nothing
def state_dict(self):
return {} # Return Empty
def step(self, closure=None):
if closure is not None:
closure()
def zero_grad(self):
pass # Do Nothing
def __repr__(self):
return "No Optimizer"
def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
extra_keys = optim_conf.keys() - valid_keys
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
r"""
.. deprecated:: v1.6
`TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in v1.8.
"""
rank_zero_deprecation(
"`TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in "
"v1.8."
)
def _validate_scheduler_optimizer(optimizers, lr_schedulers):
if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
raise MisconfigurationException(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
)
def _get_default_scheduler_config() -> Dict[str, Any]:
return {
"scheduler": None,
"name": None, # no custom name
"interval": "epoch", # after epoch is over
"frequency": 1, # every epoch/batch
"reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler
"monitor": None, # value to monitor for ReduceLROnPlateau
"strict": True, # enforce that the monitor exists for ReduceLROnPlateau
"opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified
}
_convert_to_lightning_optimizers(self)

View File

@ -33,7 +33,7 @@ from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
@ -480,8 +480,7 @@ class Trainer(
# default .predict() loop
self.predict_loop = PredictionLoop()
# Needed because of LightningOptimizer
self._lightning_optimizers = None
self._lightning_optimizers: Optional[Dict[int, LightningOptimizer]] = None
# .validate() and .test() set this when they load a checkpoint
self.validated_ckpt_path: Optional[str] = None
@ -1926,9 +1925,9 @@ class Trainer(
return SLURMEnvironment.job_id()
@property
def lightning_optimizers(self) -> List[LightningOptimizer]:
def lightning_optimizers(self) -> Dict[int, LightningOptimizer]:
if self._lightning_optimizers is None:
self.convert_to_lightning_optimizers()
_convert_to_lightning_optimizers(self)
return self._lightning_optimizers
@property

View File

@ -24,8 +24,8 @@ from torch.optim.lr_scheduler import _LRScheduler
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.optimizer import _get_default_scheduler_config
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -98,15 +98,15 @@ class _LRFinder:
self.results = {}
self._total_batch_idx = 0 # for debug purpose
def _exchange_scheduler(self, trainer: "pl.Trainer"):
"""Decorate `trainer.init_optimizers` method such that it returns the users originally specified optimizer
together with a new scheduler that that takes care of the learning rate search."""
init_optimizers = trainer.init_optimizers
def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
"""Decorate `trainer.strategy.init_optimizers` method such that it returns the user's originally specified
optimizer together with a new scheduler that that takes care of the learning rate search."""
init_optimizers = trainer.strategy.init_optimizers
@wraps(init_optimizers)
def func(model):
# Decide the structure of the output from init_optimizers
optimizers, _, _ = init_optimizers(model)
def func(trainer, model):
# Decide the structure of the output from trainer.strategy.init_optimizers
optimizers, _, _ = init_optimizers(trainer, model)
if len(optimizers) != 1:
raise MisconfigurationException(
@ -232,7 +232,7 @@ def lr_find(
trainer.save_checkpoint(str(save_path))
# Configure optimizer and scheduler
trainer.init_optimizers = lr_finder._exchange_scheduler(trainer)
trainer.strategy.init_optimizers = lr_finder._exchange_scheduler(trainer, model)
# Fit, lr & loss logged in callback
trainer.tuner._run(model)
@ -278,7 +278,6 @@ def __lr_finder_dump_params(trainer, model):
"max_steps": trainer.max_steps,
"checkpoint_callback": trainer.checkpoint_callback,
"current_epoch": trainer.current_epoch,
"init_optimizers": trainer.init_optimizers,
}
@ -289,7 +288,6 @@ def __lr_finder_restore_params(trainer, model):
trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
trainer.init_optimizers = trainer.__dumped_params["init_optimizers"]
del trainer.__dumped_params

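Both the stochastic weight averaging callback and the learning-rate finder now import `_get_default_scheduler_config` from `core.optimizer`. A hedged sketch of the default scheduler config and of how `_configure_schedulers` normalizes a bare scheduler versus a scheduler dict (keys as defined in `core/optimizer.py` above; private helpers that may change):

```python
import torch
from torch import nn, optim

from pytorch_lightning.core.optimizer import _configure_schedulers, _get_default_scheduler_config

opt = optim.SGD([nn.Parameter(torch.zeros(1))], lr=0.1)
step_lr = optim.lr_scheduler.StepLR(opt, step_size=1)

# defaults: scheduler=None, name=None, interval="epoch", frequency=1,
# reduce_on_plateau=False, monitor=None, strict=True, opt_idx=None
defaults = _get_default_scheduler_config()
assert defaults["interval"] == "epoch" and defaults["frequency"] == 1

# a bare scheduler is wrapped with the defaults ...
(config,) = _configure_schedulers([step_lr], monitor=None, is_manual_optimization=False)
assert config["scheduler"] is step_lr and config["interval"] == "epoch"

# ... while a dict can override selected keys, e.g. stepping every batch instead of every epoch
(config,) = _configure_schedulers([{"scheduler": step_lr, "interval": "step"}], None, False)
assert config["interval"] == "step"
```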
View File

@ -140,6 +140,24 @@ def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir):
_ = trainer.should_rank_save_checkpoint
def test_v1_8_0_trainer_optimizers_mixin():
trainer = Trainer()
model = BoringModel()
trainer.strategy.connect(model)
trainer.lightning_module.trainer = trainer
with pytest.deprecated_call(
match=r"`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8."
):
trainer.init_optimizers(model)
with pytest.deprecated_call(
match=r"`TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in "
"v1.8."
):
trainer.convert_to_lightning_optimizers()
def test_v1_8_0_deprecate_trainer_callback_hook_mixin():
methods_with_self = [
"on_before_accelerator_backend_setup",

View File

@ -15,7 +15,6 @@ import torch
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.utilities import _StrategyType
from tests.helpers import BoringModel
from tests.helpers.utils import get_default_logger, load_model_from_checkpoint, reset_seed
@ -70,7 +69,7 @@ def run_model_test(
assert change_ratio > 0.03, f"the model is changed of {change_ratio}"
# test model loading
pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))
_ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))
# test new model accuracy
test_loaders = model.test_dataloader() if not data else data.test_dataloader()
@ -82,12 +81,6 @@ def run_model_test(
run_model_prediction(model, dataloader, min_acc=min_acc)
if with_hpc:
if trainer._distrib_type in (_StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2):
# on hpc this would work fine... but need to hack it for the purpose of the test
trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers(
pretrained_model
)
# test HPC saving
trainer.checkpoint_connector.hpc_save(save_dir, logger)
# test HPC loading

View File

@ -19,7 +19,11 @@ from torch import optim
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.core.optimizer import (
_configure_optimizers,
_configure_schedulers,
_init_optimizers_and_lr_schedulers,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
@ -106,7 +110,7 @@ def test_onecyclelr_with_epoch_interval_warns():
optimizer = optim.Adam(model.parameters())
lr_scheduler = {"scheduler": optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, total_steps=3)}
with pytest.warns(RuntimeWarning, match="Are you sure you didn't mean 'interval': 'step'?"):
TrainerOptimizersMixin._configure_schedulers([lr_scheduler], None, False)
_configure_schedulers([lr_scheduler], None, False)
def test_reducelronplateau_scheduling(tmpdir):
@ -144,6 +148,8 @@ def test_reducelronplateau_scheduling(tmpdir):
def test_optimizer_return_options(tmpdir):
trainer = Trainer(default_root_dir=tmpdir)
model = BoringModel()
trainer.strategy.connect(model)
trainer.lightning_module.trainer = trainer
# single optimizer
opt_a = optim.Adam(model.parameters(), lr=0.002)
@ -153,18 +159,18 @@ def test_optimizer_return_options(tmpdir):
# single optimizer
model.configure_optimizers = lambda: opt_a
opt, lr_sched, freq = trainer.init_optimizers(model)
opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
assert len(opt) == 1 and len(lr_sched) == len(freq) == 0
# opt tuple
model.configure_optimizers = lambda: (opt_a, opt_b)
opt, lr_sched, freq = trainer.init_optimizers(model)
opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
assert opt == [opt_a, opt_b]
assert len(lr_sched) == len(freq) == 0
# opt list
model.configure_optimizers = lambda: [opt_a, opt_b]
opt, lr_sched, freq = trainer.init_optimizers(model)
opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
assert opt == [opt_a, opt_b]
assert len(lr_sched) == len(freq) == 0
@ -181,7 +187,7 @@ def test_optimizer_return_options(tmpdir):
# opt tuple of 2 lists
model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
opt, lr_sched, freq = trainer.init_optimizers(model)
opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
assert len(opt) == len(lr_sched) == 1
assert len(freq) == 0
assert opt[0] == opt_a
@ -189,7 +195,7 @@ def test_optimizer_return_options(tmpdir):
# opt tuple of 1 list
model.configure_optimizers = lambda: ([opt_a], scheduler_a)
opt, lr_sched, freq = trainer.init_optimizers(model)
opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
assert len(opt) == len(lr_sched) == 1
assert len(freq) == 0
assert opt[0] == opt_a
@ -197,7 +203,7 @@ def test_optimizer_return_options(tmpdir):
# opt single dictionary
model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a}
opt, lr_sched, freq = trainer.init_optimizers(model)
opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
assert len(opt) == len(lr_sched) == 1
assert len(freq) == 0
assert opt[0] == opt_a
@ -208,7 +214,7 @@ def test_optimizer_return_options(tmpdir):
{"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
{"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
)
opt, lr_sched, freq = trainer.init_optimizers(model)
opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
assert len(opt) == len(lr_sched) == len(freq) == 2
assert opt[0] == opt_a
ref_lr_sched["opt_idx"] = 0
@ -436,7 +442,7 @@ def test_optimizer_config_dict_with_extra_keys_warns(tmpdir):
"bar": 2,
}
with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the optimizer configuration: \{.+\}"):
TrainerOptimizersMixin._configure_optimizers(optim_conf)
_configure_optimizers(optim_conf)
def test_multiple_optimizer_config_dicts_with_extra_keys_warns(tmpdir):
@ -451,7 +457,7 @@ def test_multiple_optimizer_config_dicts_with_extra_keys_warns(tmpdir):
{"optimizer": optimizer2, "lr_scheduler": lr_scheduler_config_2, "foo": 1, "bar": 2},
]
with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the optimizer configuration: \{.+\}"):
TrainerOptimizersMixin._configure_optimizers(optim_conf)
_configure_optimizers(optim_conf)
def test_lr_scheduler_with_unknown_interval_raises(tmpdir):