diff --git a/CHANGELOG.md b/CHANGELOG.md index bb326232fc..4bfd337681 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068)) +- Deprecated `TrainerOptimizersMixin` and moved functionality to `core/optimizer.py`([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155)) + + - Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148)) ### Removed @@ -351,6 +354,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990)) +- Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155)) + ## [1.5.7] - 2021-12-21 diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index bde9c1b5c2..504d477d3f 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,7 +24,7 @@ from torch.optim.swa_utils import SWALR import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config +from pytorch_lightning.core.optimizer import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index c67decf97c..af036cdd65 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import weakref from contextlib import contextmanager -from typing import Any, Callable, Generator, Optional +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union from weakref import proxy +import torch +from torch import optim from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -54,7 +57,8 @@ class LightningOptimizer: return self._optimizer def _on_trainer_init(self, trainer: "pl.Trainer") -> None: - self._trainer = proxy(trainer) + # check if trainer is already of type weakproxy since we can't call proxy on a weakproxy + self._trainer = trainer if isinstance(trainer, weakref.ProxyType) else proxy(trainer) for opt_idx, opt in enumerate(trainer.optimizers): if opt == self._optimizer: self._optimizer_idx = opt_idx @@ -162,3 +166,227 @@ class LightningOptimizer: assert trainer is not None with trainer.profiler.profile(profiler_action): trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) + + +def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]: + """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" + model.trainer._lightning_optimizers = None + optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model) + + if optim_conf is None: + rank_zero_warn( + "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer", + ) + optim_conf = _MockOptimizer() + + optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf) + lr_schedulers = _configure_schedulers(lr_schedulers, monitor, not model.automatic_optimization) + _validate_scheduler_optimizer(optimizers, lr_schedulers) + return optimizers, lr_schedulers, optimizer_frequencies + + +def _configure_optimizers( + optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple] +) -> Tuple[List, List, List, Optional[str]]: + optimizers, lr_schedulers, optimizer_frequencies = [], [], [] + monitor = None + + # single output, single optimizer + if isinstance(optim_conf, Optimizer): + optimizers = [optim_conf] + # two lists, optimizer + lr schedulers + elif ( + isinstance(optim_conf, (list, tuple)) + and len(optim_conf) == 2 + and isinstance(optim_conf[0], list) + and all(isinstance(opt, Optimizer) for opt in optim_conf[0]) + ): + opt, sch = optim_conf + optimizers = opt + lr_schedulers = sch if isinstance(sch, list) else [sch] + # single dictionary + elif isinstance(optim_conf, dict): + _validate_optim_conf(optim_conf) + optimizers = [optim_conf["optimizer"]] + monitor = optim_conf.get("monitor", None) + lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else [] + # multiple dictionaries + elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf): + for opt_dict in optim_conf: + _validate_optim_conf(opt_dict) + optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] + scheduler_dict = ( + lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx) + if isinstance(scheduler, dict) + else {"scheduler": scheduler, "opt_idx": opt_idx} + ) + + lr_schedulers = [ + scheduler_dict(opt_dict["lr_scheduler"], opt_idx) + for opt_idx, opt_dict in enumerate(optim_conf) + if "lr_scheduler" in opt_dict + ] + optimizer_frequencies = [ + opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None + ] + # assert that if frequencies are present, they are given for all optimizers + if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): + raise ValueError("A frequency must be given to each optimizer.") + # single list or tuple, multiple optimizer + elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf): + optimizers = list(optim_conf) + # unknown configuration + else: + raise MisconfigurationException( + "Unknown configuration for model optimizers." + " Output from `model.configure_optimizers()` should be one of:\n" + " * `Optimizer`\n" + " * [`Optimizer`]\n" + " * ([`Optimizer`], [`_LRScheduler`])\n" + ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n' + ' * A list of the previously described dict format, with an optional "frequency" key (int)' + ) + return optimizers, lr_schedulers, optimizer_frequencies, monitor + + +def _configure_schedulers( + schedulers: list, monitor: Optional[str], is_manual_optimization: bool +) -> List[Dict[str, Any]]: + """Convert each scheduler into dict structure with relevant information.""" + lr_schedulers = [] + default_config = _get_default_scheduler_config() + # TODO: move is_manual_optimization check out of for loop + for scheduler in schedulers: + if is_manual_optimization: + if isinstance(scheduler, dict): + invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"} + keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys] + + if keys_to_warn: + rank_zero_warn( + f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored." + " You need to call `lr_scheduler.step()` manually in manual optimization.", + category=RuntimeWarning, + ) + + scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys} + lr_schedulers.append({**default_config, **scheduler}) + else: + lr_schedulers.append({**default_config, "scheduler": scheduler}) + else: + if isinstance(scheduler, dict): + # check provided keys + extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()] + if extra_keys: + rank_zero_warn( + f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning + ) + if "scheduler" not in scheduler: + raise MisconfigurationException( + 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler' + ) + if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"): + raise MisconfigurationException( + 'The "interval" key in lr scheduler dict must be "step" or "epoch"' + f' but is "{scheduler["interval"]}"' + ) + scheduler["reduce_on_plateau"] = isinstance( + scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau + ) + if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None: + raise MisconfigurationException( + "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used." + ' For example: {"optimizer": optimizer, "lr_scheduler":' + ' {"scheduler": scheduler, "monitor": "your_loss"}}' + ) + is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR) + if is_one_cycle and scheduler.get("interval", "epoch") == "epoch": + rank_zero_warn( + "A `OneCycleLR` scheduler is using 'interval': 'epoch'." + " Are you sure you didn't mean 'interval': 'step'?", + category=RuntimeWarning, + ) + lr_schedulers.append({**default_config, **scheduler}) + elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): + if monitor is None: + raise MisconfigurationException( + "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`" + " scheduler is used. For example:" + ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' + ) + lr_schedulers.append( + {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor} + ) + elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): + lr_schedulers.append({**default_config, "scheduler": scheduler}) + else: + raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') + return lr_schedulers + + +def _get_default_scheduler_config() -> Dict[str, Any]: + return { + "scheduler": None, + "name": None, # no custom name + "interval": "epoch", # after epoch is over + "frequency": 1, # every epoch/batch + "reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler + "monitor": None, # value to monitor for ReduceLROnPlateau + "strict": True, # enforce that the monitor exists for ReduceLROnPlateau + "opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified + } + + +def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None: + if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers): + raise MisconfigurationException( + "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`." + ) + + +def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None: + valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"} + extra_keys = optim_conf.keys() - valid_keys + if extra_keys: + rank_zero_warn( + f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning + ) + + +def _convert_to_lightning_optimizers(trainer: "pl.Trainer") -> None: + def _convert_to_lightning_optimizer(optimizer: Optimizer) -> LightningOptimizer: + if not isinstance(optimizer, LightningOptimizer): + optimizer = LightningOptimizer(optimizer) # type: ignore [assignment] + optimizer._on_trainer_init(trainer) + return optimizer # type: ignore [return-value] + + trainer._lightning_optimizers = { # type: ignore [assignment] + opt_idx: _convert_to_lightning_optimizer(opt) for opt_idx, opt in enumerate(trainer.optimizers) + } + + +class _MockOptimizer(Optimizer): + """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from + `configure_optimizers`.""" + + def __init__(self) -> None: + super().__init__([torch.zeros(1)], {}) + + def add_param_group(self, param_group: Dict[Any, Any]) -> None: + pass # Do Nothing + + def load_state_dict(self, state_dict: Dict[Any, Any]) -> None: + pass # Do Nothing + + def state_dict(self) -> Dict[str, Any]: + return {} # Return Empty + + def step(self, closure: Callable = None) -> None: + if closure is not None: + closure() + + def zero_grad(self, set_to_none: Optional[bool] = False) -> None: + pass # Do Nothing + + def __repr__(self) -> str: + return "No Optimizer" diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 14f20a5f5e..2a221724fc 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -31,7 +31,7 @@ from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment @@ -347,7 +347,7 @@ class DDPStrategy(ParallelStrategy): del optimizer trainer = self.lightning_module.trainer trainer.optimizers = optimizers - trainer.convert_to_lightning_optimizers() + _convert_to_lightning_optimizers(trainer) def configure_ddp(self) -> None: self.pre_configure_ddp() diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index 3824b194c5..8b34061d18 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -27,11 +27,11 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler import pytorch_lightning as pl +from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp import DDPStrategy -from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -446,9 +446,7 @@ class DeepSpeedStrategy(DDPStrategy): self._initialize_deepspeed_inference(model) def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]: - optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers( - self.lightning_module - ) + optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module) if len(optimizers) > 1 or len(schedulers) > 1: raise MisconfigurationException( "DeepSpeed currently only supports single optimizer, single optional scheduler." diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 67b848040c..3900b0c798 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -19,7 +19,7 @@ from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only @@ -50,7 +50,7 @@ class DDPShardedStrategy(DDPStrategy): optimizers=trainer.optimizers, ) trainer.optimizers = optimizers - trainer.convert_to_lightning_optimizers() + _convert_to_lightning_optimizers(trainer) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Wraps the model and optimizers with fairscale components. diff --git a/pytorch_lightning/strategies/training_type_plugin.py b/pytorch_lightning/strategies/training_type_plugin.py index 9d7bb51b9a..2ee586a724 100644 --- a/pytorch_lightning/strategies/training_type_plugin.py +++ b/pytorch_lightning/strategies/training_type_plugin.py @@ -22,6 +22,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader import pytorch_lightning as pl +from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -377,7 +378,7 @@ class Strategy(ABC): return dataloader def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"): - return trainer.init_optimizers(model) + return _init_optimizers_and_lr_schedulers(model) @property def restore_checkpoint_after_setup(self) -> bool: diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 6635dd8bb2..0b44786873 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -13,239 +13,43 @@ # limitations under the License. from abc import ABC -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from torch import optim -from torch.optim.optimizer import Optimizer +from typing import List, Optional, Tuple import pytorch_lightning as pl -from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import LRSchedulerConfig +from pytorch_lightning.core.optimizer import ( + _convert_to_lightning_optimizers, + _init_optimizers_and_lr_schedulers, + LightningOptimizer, +) +from pytorch_lightning.utilities import rank_zero_deprecation class TrainerOptimizersMixin(ABC): + r""" + .. deprecated:: v1.6 + The `TrainerOptimizersMixin` was deprecated in v1.6 and will be removed in v1.8. + """ _lightning_optimizers: Optional[List[LightningOptimizer]] def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]: + r""" + .. deprecated:: v1.6 + `TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8. + """ + rank_zero_deprecation( + "`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8." + ) pl_module = self.lightning_module or model - self._lightning_optimizers = None - optim_conf = self._call_lightning_module_hook("configure_optimizers", pl_module=pl_module) - if optim_conf is None: - rank_zero_warn( - "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer", - ) - optim_conf = _MockOptimizer() - - optimizers, lr_schedulers, optimizer_frequencies, monitor = self._configure_optimizers(optim_conf) - lr_schedulers = self._configure_schedulers(lr_schedulers, monitor, not pl_module.automatic_optimization) - _validate_scheduler_optimizer(optimizers, lr_schedulers) - return optimizers, lr_schedulers, optimizer_frequencies - - @staticmethod - def _configure_optimizers( - optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple] - ) -> Tuple[List, List, List, Optional[str]]: - optimizers, lr_schedulers, optimizer_frequencies = [], [], [] - monitor = None - - # single output, single optimizer - if isinstance(optim_conf, Optimizer): - optimizers = [optim_conf] - # two lists, optimizer + lr schedulers - elif ( - isinstance(optim_conf, (list, tuple)) - and len(optim_conf) == 2 - and isinstance(optim_conf[0], list) - and all(isinstance(opt, Optimizer) for opt in optim_conf[0]) - ): - opt, sch = optim_conf - optimizers = opt - lr_schedulers = sch if isinstance(sch, list) else [sch] - # single dictionary - elif isinstance(optim_conf, dict): - _validate_optim_conf(optim_conf) - optimizers = [optim_conf["optimizer"]] - monitor = optim_conf.get("monitor", None) - lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else [] - # multiple dictionaries - elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf): - for opt_dict in optim_conf: - _validate_optim_conf(opt_dict) - optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] - scheduler_dict = ( - lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx) - if isinstance(scheduler, dict) - else {"scheduler": scheduler, "opt_idx": opt_idx} - ) - - lr_schedulers = [ - scheduler_dict(opt_dict["lr_scheduler"], opt_idx) - for opt_idx, opt_dict in enumerate(optim_conf) - if "lr_scheduler" in opt_dict - ] - optimizer_frequencies = [ - opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None - ] - # assert that if frequencies are present, they are given for all optimizers - if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): - raise ValueError("A frequency must be given to each optimizer.") - # single list or tuple, multiple optimizer - elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf): - optimizers = list(optim_conf) - # unknown configuration - else: - raise MisconfigurationException( - "Unknown configuration for model optimizers." - " Output from `model.configure_optimizers()` should either be:\n" - " * `torch.optim.Optimizer`\n" - " * [`torch.optim.Optimizer`]\n" - " * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n" - ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n' - ' * A list of the previously described dict format, with an optional "frequency" key (int)' - ) - return optimizers, lr_schedulers, optimizer_frequencies, monitor + return _init_optimizers_and_lr_schedulers(pl_module) def convert_to_lightning_optimizers(self): - def _convert_to_lightning_optimizer(trainer, optimizer): - if not isinstance(optimizer, LightningOptimizer): - optimizer = LightningOptimizer(optimizer) - optimizer._on_trainer_init(trainer) - return optimizer - - self._lightning_optimizers = { - opt_idx: _convert_to_lightning_optimizer(self, opt) for opt_idx, opt in enumerate(self.optimizers) - } - - @staticmethod - def _configure_schedulers( - schedulers: list, monitor: Optional[str], is_manual_optimization: bool - ) -> List[LRSchedulerConfig]: - """Convert each scheduler into dict structure with relevant information.""" - lr_schedulers = [] - default_config = _get_default_scheduler_config() - for scheduler in schedulers: - if is_manual_optimization: - if isinstance(scheduler, dict): - invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"} - keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys] - - if keys_to_warn: - rank_zero_warn( - f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored." - " You need to call `lr_scheduler.step()` manually in manual optimization.", - category=RuntimeWarning, - ) - - scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys} - lr_schedulers.append({**default_config, **scheduler}) - else: - lr_schedulers.append({**default_config, "scheduler": scheduler}) - else: - if isinstance(scheduler, dict): - # check provided keys - extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()] - if extra_keys: - rank_zero_warn( - f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning - ) - if "scheduler" not in scheduler: - raise MisconfigurationException( - 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler' - ) - if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"): - raise MisconfigurationException( - 'The "interval" key in lr scheduler dict must be "step" or "epoch"' - f' but is "{scheduler["interval"]}"' - ) - scheduler["reduce_on_plateau"] = isinstance( - scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau - ) - if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None: - raise MisconfigurationException( - "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used." - ' For example: {"optimizer": optimizer, "lr_scheduler":' - ' {"scheduler": scheduler, "monitor": "your_loss"}}' - ) - is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR) - if is_one_cycle and scheduler.get("interval", "epoch") == "epoch": - rank_zero_warn( - "A `OneCycleLR` scheduler is using 'interval': 'epoch'." - " Are you sure you didn't mean 'interval': 'step'?", - category=RuntimeWarning, - ) - lr_schedulers.append({**default_config, **scheduler}) - elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): - if monitor is None: - raise MisconfigurationException( - "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`" - " scheduler is used. For example:" - ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' - ) - lr_schedulers.append( - {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor} - ) - elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): - lr_schedulers.append({**default_config, "scheduler": scheduler}) - else: - raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') - return lr_schedulers - - -class _MockOptimizer(Optimizer): - """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from - `configure_optimizers`.""" - - def __init__(self): - super().__init__([torch.zeros(1)], {}) - - def add_param_group(self, param_group): - pass # Do Nothing - - def load_state_dict(self, state_dict): - pass # Do Nothing - - def state_dict(self): - return {} # Return Empty - - def step(self, closure=None): - if closure is not None: - closure() - - def zero_grad(self): - pass # Do Nothing - - def __repr__(self): - return "No Optimizer" - - -def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None: - valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"} - extra_keys = optim_conf.keys() - valid_keys - if extra_keys: - rank_zero_warn( - f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning + r""" + .. deprecated:: v1.6 + `TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in v1.8. + """ + rank_zero_deprecation( + "`TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in " + "v1.8." ) - - -def _validate_scheduler_optimizer(optimizers, lr_schedulers): - if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers): - raise MisconfigurationException( - "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`." - ) - - -def _get_default_scheduler_config() -> Dict[str, Any]: - return { - "scheduler": None, - "name": None, # no custom name - "interval": "epoch", # after epoch is over - "frequency": 1, # every epoch/batch - "reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler - "monitor": None, # value to monitor for ReduceLROnPlateau - "strict": True, # enforce that the monitor exists for ReduceLROnPlateau - "opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified - } + _convert_to_lightning_optimizers(self) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 179462a805..434ff45519 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -33,7 +33,7 @@ from pytorch_lightning.accelerators import Accelerator, IPUAccelerator from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection from pytorch_lightning.loggers.tensorboard import TensorBoardLogger @@ -480,8 +480,7 @@ class Trainer( # default .predict() loop self.predict_loop = PredictionLoop() - # Needed because of LightningOptimizer - self._lightning_optimizers = None + self._lightning_optimizers: Optional[Dict[int, LightningOptimizer]] = None # .validate() and .test() set this when they load a checkpoint self.validated_ckpt_path: Optional[str] = None @@ -1926,9 +1925,9 @@ class Trainer( return SLURMEnvironment.job_id() @property - def lightning_optimizers(self) -> List[LightningOptimizer]: + def lightning_optimizers(self) -> Dict[int, LightningOptimizer]: if self._lightning_optimizers is None: - self.convert_to_lightning_optimizers() + _convert_to_lightning_optimizers(self) return self._lightning_optimizers @property diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index ba4c737ed0..5b5c6adf32 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -24,8 +24,8 @@ from torch.optim.lr_scheduler import _LRScheduler import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.optimizer import _get_default_scheduler_config from pytorch_lightning.loggers.base import DummyLogger -from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -98,15 +98,15 @@ class _LRFinder: self.results = {} self._total_batch_idx = 0 # for debug purpose - def _exchange_scheduler(self, trainer: "pl.Trainer"): - """Decorate `trainer.init_optimizers` method such that it returns the users originally specified optimizer - together with a new scheduler that that takes care of the learning rate search.""" - init_optimizers = trainer.init_optimizers + def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"): + """Decorate `trainer.strategy.init_optimizers` method such that it returns the user's originally specified + optimizer together with a new scheduler that that takes care of the learning rate search.""" + init_optimizers = trainer.strategy.init_optimizers @wraps(init_optimizers) - def func(model): - # Decide the structure of the output from init_optimizers - optimizers, _, _ = init_optimizers(model) + def func(trainer, model): + # Decide the structure of the output from trainer.strategy.init_optimizers + optimizers, _, _ = init_optimizers(trainer, model) if len(optimizers) != 1: raise MisconfigurationException( @@ -232,7 +232,7 @@ def lr_find( trainer.save_checkpoint(str(save_path)) # Configure optimizer and scheduler - trainer.init_optimizers = lr_finder._exchange_scheduler(trainer) + trainer.strategy.init_optimizers = lr_finder._exchange_scheduler(trainer, model) # Fit, lr & loss logged in callback trainer.tuner._run(model) @@ -278,7 +278,6 @@ def __lr_finder_dump_params(trainer, model): "max_steps": trainer.max_steps, "checkpoint_callback": trainer.checkpoint_callback, "current_epoch": trainer.current_epoch, - "init_optimizers": trainer.init_optimizers, } @@ -289,7 +288,6 @@ def __lr_finder_restore_params(trainer, model): trainer.fit_loop.global_step = trainer.__dumped_params["global_step"] trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"] trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"] - trainer.init_optimizers = trainer.__dumped_params["init_optimizers"] del trainer.__dumped_params diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 46f0a6c91d..e2ce852fc9 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -140,6 +140,24 @@ def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir): _ = trainer.should_rank_save_checkpoint +def test_v1_8_0_trainer_optimizers_mixin(): + trainer = Trainer() + model = BoringModel() + trainer.strategy.connect(model) + trainer.lightning_module.trainer = trainer + + with pytest.deprecated_call( + match=r"`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8." + ): + trainer.init_optimizers(model) + + with pytest.deprecated_call( + match=r"`TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in " + "v1.8." + ): + trainer.convert_to_lightning_optimizers() + + def test_v1_8_0_deprecate_trainer_callback_hook_mixin(): methods_with_self = [ "on_before_accelerator_backend_setup", diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index e67ee5f52e..2f7d2d584d 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -15,7 +15,6 @@ import torch from torchmetrics.functional import accuracy from pytorch_lightning import LightningDataModule, LightningModule, Trainer -from pytorch_lightning.utilities import _StrategyType from tests.helpers import BoringModel from tests.helpers.utils import get_default_logger, load_model_from_checkpoint, reset_seed @@ -70,7 +69,7 @@ def run_model_test( assert change_ratio > 0.03, f"the model is changed of {change_ratio}" # test model loading - pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model)) + _ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model)) # test new model accuracy test_loaders = model.test_dataloader() if not data else data.test_dataloader() @@ -82,12 +81,6 @@ def run_model_test( run_model_prediction(model, dataloader, min_acc=min_acc) if with_hpc: - if trainer._distrib_type in (_StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2): - # on hpc this would work fine... but need to hack it for the purpose of the test - trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers( - pretrained_model - ) - # test HPC saving trainer.checkpoint_connector.hpc_save(save_dir, logger) # test HPC loading diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 4a99b3318f..9c90a5def7 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -19,7 +19,11 @@ from torch import optim from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin +from pytorch_lightning.core.optimizer import ( + _configure_optimizers, + _configure_schedulers, + _init_optimizers_and_lr_schedulers, +) from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -106,7 +110,7 @@ def test_onecyclelr_with_epoch_interval_warns(): optimizer = optim.Adam(model.parameters()) lr_scheduler = {"scheduler": optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, total_steps=3)} with pytest.warns(RuntimeWarning, match="Are you sure you didn't mean 'interval': 'step'?"): - TrainerOptimizersMixin._configure_schedulers([lr_scheduler], None, False) + _configure_schedulers([lr_scheduler], None, False) def test_reducelronplateau_scheduling(tmpdir): @@ -144,6 +148,8 @@ def test_reducelronplateau_scheduling(tmpdir): def test_optimizer_return_options(tmpdir): trainer = Trainer(default_root_dir=tmpdir) model = BoringModel() + trainer.strategy.connect(model) + trainer.lightning_module.trainer = trainer # single optimizer opt_a = optim.Adam(model.parameters(), lr=0.002) @@ -153,18 +159,18 @@ def test_optimizer_return_options(tmpdir): # single optimizer model.configure_optimizers = lambda: opt_a - opt, lr_sched, freq = trainer.init_optimizers(model) + opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) assert len(opt) == 1 and len(lr_sched) == len(freq) == 0 # opt tuple model.configure_optimizers = lambda: (opt_a, opt_b) - opt, lr_sched, freq = trainer.init_optimizers(model) + opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) assert opt == [opt_a, opt_b] assert len(lr_sched) == len(freq) == 0 # opt list model.configure_optimizers = lambda: [opt_a, opt_b] - opt, lr_sched, freq = trainer.init_optimizers(model) + opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) assert opt == [opt_a, opt_b] assert len(lr_sched) == len(freq) == 0 @@ -181,7 +187,7 @@ def test_optimizer_return_options(tmpdir): # opt tuple of 2 lists model.configure_optimizers = lambda: ([opt_a], [scheduler_a]) - opt, lr_sched, freq = trainer.init_optimizers(model) + opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == 1 assert len(freq) == 0 assert opt[0] == opt_a @@ -189,7 +195,7 @@ def test_optimizer_return_options(tmpdir): # opt tuple of 1 list model.configure_optimizers = lambda: ([opt_a], scheduler_a) - opt, lr_sched, freq = trainer.init_optimizers(model) + opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == 1 assert len(freq) == 0 assert opt[0] == opt_a @@ -197,7 +203,7 @@ def test_optimizer_return_options(tmpdir): # opt single dictionary model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a} - opt, lr_sched, freq = trainer.init_optimizers(model) + opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == 1 assert len(freq) == 0 assert opt[0] == opt_a @@ -208,7 +214,7 @@ def test_optimizer_return_options(tmpdir): {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1}, {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5}, ) - opt, lr_sched, freq = trainer.init_optimizers(model) + opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == len(freq) == 2 assert opt[0] == opt_a ref_lr_sched["opt_idx"] = 0 @@ -436,7 +442,7 @@ def test_optimizer_config_dict_with_extra_keys_warns(tmpdir): "bar": 2, } with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the optimizer configuration: \{.+\}"): - TrainerOptimizersMixin._configure_optimizers(optim_conf) + _configure_optimizers(optim_conf) def test_multiple_optimizer_config_dicts_with_extra_keys_warns(tmpdir): @@ -451,7 +457,7 @@ def test_multiple_optimizer_config_dicts_with_extra_keys_warns(tmpdir): {"optimizer": optimizer2, "lr_scheduler": lr_scheduler_config_2, "foo": 1, "bar": 2}, ] with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the optimizer configuration: \{.+\}"): - TrainerOptimizersMixin._configure_optimizers(optim_conf) + _configure_optimizers(optim_conf) def test_lr_scheduler_with_unknown_interval_raises(tmpdir):