# lightning/pytorch_lightning/trainer/optimizers.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import optim
from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TrainerOptimizersMixin(ABC):

    _lightning_optimizers: Optional[List[LightningOptimizer]]

    def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
        self._lightning_optimizers = None
        optim_conf = model.configure_optimizers()
        if optim_conf is None:
            rank_zero_warn(
                '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer',
                UserWarning,
            )
            optim_conf = _MockOptimizer()

        optimizers, lr_schedulers, optimizer_frequencies = [], [], []
        monitor = None

        # single output, single optimizer
        if isinstance(optim_conf, Optimizer):
            optimizers = [optim_conf]
        # two lists, optimizer + lr schedulers
        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
            opt, sch = optim_conf
            optimizers = opt
            lr_schedulers = sch if isinstance(sch, list) else [sch]
        # single dictionary
        elif isinstance(optim_conf, dict):
            optimizers = [optim_conf["optimizer"]]
            monitor = optim_conf.get('monitor', None)
            lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
        # multiple dictionaries
        elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
            lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict]
            optimizer_frequencies = [
                opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
            ]
            # assert that if frequencies are present, they are given for all optimizers
            if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
                raise ValueError("A frequency must be given to each optimizer.")
        # single list or tuple, multiple optimizer
        elif isinstance(optim_conf, (list, tuple)):
            optimizers = list(optim_conf)
        # unknown configuration
        else:
            raise MisconfigurationException(
                'Unknown configuration for model optimizers.'
                ' Output from `model.configure_optimizers()` should either be:\n'
                ' * `torch.optim.Optimizer`\n'
                ' * [`torch.optim.Optimizer`]\n'
                ' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n'
                ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
                ' * A list of the previously described dict format, with an optional "frequency" key (int)'
            )

        lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
        _validate_scheduler_optimizer(optimizers, lr_schedulers)

        return optimizers, lr_schedulers, optimizer_frequencies
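
    # Illustrative sketch, not part of the original module: `init_optimizers` accepts any of the
    # return shapes listed in the error message above and normalizes them to three lists. For
    # example, a hypothetical `configure_optimizers` returning the single-dict format
    #
    #     def configure_optimizers(self):
    #         opt = torch.optim.Adam(self.parameters(), lr=1e-3)
    #         sch = torch.optim.lr_scheduler.StepLR(opt, step_size=10)
    #         return {"optimizer": opt, "lr_scheduler": sch}
    #
    # is normalized to `([opt], [<scheduler dict built by configure_schedulers>], [])`.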

    def convert_to_lightning_optimizers(self):

        def _convert_to_lightning_optimizer(trainer, optimizer):
            if not isinstance(optimizer, LightningOptimizer):
                optimizer = LightningOptimizer(optimizer)
            optimizer._on_trainer_init(trainer)
            return optimizer

        self._lightning_optimizers = {
            opt_idx: _convert_to_lightning_optimizer(self, opt)
            for opt_idx, opt in enumerate(self.optimizers)
        }

    def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
        # Convert each scheduler into dict structure with relevant information
        lr_schedulers = []
        default_config = _get_default_scheduler_config()
        for scheduler in schedulers:
            if isinstance(scheduler, dict):
                # check provided keys
                extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
                if extra_keys:
                    rank_zero_warn(f'Found unsupported keys in the lr scheduler dict: {extra_keys}', RuntimeWarning)
                if 'scheduler' not in scheduler:
                    raise MisconfigurationException(
                        'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                    )
                if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
                    raise MisconfigurationException(
                        'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                        f' but is "{scheduler["interval"]}"'
                    )
                scheduler['reduce_on_plateau'] = isinstance(
                    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
                )
                if scheduler['reduce_on_plateau'] and scheduler.get('monitor', None) is None:
                    raise MisconfigurationException(
                        'The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
                        ' For example: {"optimizer": optimizer, "lr_scheduler":'
                        ' {"scheduler": scheduler, "monitor": "your_loss"}}'
                    )
                lr_schedulers.append({**default_config, **scheduler})
            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                if monitor is None:
                    raise MisconfigurationException(
                        '`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
                        ' For example:'
                        ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                    )
                lr_schedulers.append({
                    **default_config,
                    'scheduler': scheduler,
                    'reduce_on_plateau': True,
                    'monitor': monitor,
                })
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                lr_schedulers.append({**default_config, 'scheduler': scheduler})
            else:
                raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
        return lr_schedulers
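
    # Illustrative sketch, not part of the original module: a scheduler dict that passes the
    # checks in `configure_schedulers` above. The `optimizer` name is assumed to be a previously
    # constructed `torch.optim.Optimizer`.
    #
    #     lr_scheduler_dict = {
    #         'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
    #         'interval': 'epoch',    # must be 'step' or 'epoch'
    #         'frequency': 1,         # how often the scheduler is stepped
    #         'monitor': 'val_loss',  # required because ReduceLROnPlateau is used
    #     }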


class _MockOptimizer(Optimizer):
    """The `_MockOptimizer` will be used in place of an optimizer in the event that `None`
    is returned from `configure_optimizers`.
    """

    def __init__(self):
        super().__init__([torch.zeros(1)], {})

    def add_param_group(self, param_group):
        pass  # Do Nothing

    def load_state_dict(self, state_dict):
        pass  # Do Nothing

    def state_dict(self):
        return {}  # Return Empty

    def step(self, closure=None):
        if closure is not None:
            closure()

    def zero_grad(self):
        pass  # Do Nothing

    def __repr__(self):
        return 'No Optimizer'


def _validate_scheduler_optimizer(optimizers, lr_schedulers):
    if any(sch['scheduler'].optimizer not in optimizers for sch in lr_schedulers):
        raise MisconfigurationException(
            "Some schedulers are attached to an optimizer that wasn't returned from `configure_optimizers`."
        )


def _get_default_scheduler_config() -> Dict[str, Any]:
    return {
        'scheduler': None,
        'name': None,  # no custom name
        'interval': 'epoch',  # after epoch is over
        'frequency': 1,  # every epoch/batch
        'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
        'monitor': None,  # value to monitor for ReduceLROnPlateau
        'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
    }
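
# Illustrative sketch, not part of the original module: `configure_schedulers` merges a
# user-provided dict on top of these defaults, so only the keys the user sets are overridden.
# `scheduler` here stands for any previously constructed lr scheduler instance.
#
#     user_dict = {'scheduler': scheduler, 'interval': 'step', 'frequency': 100}
#     full_dict = {**_get_default_scheduler_config(), **user_dict}
#     # -> {'scheduler': scheduler, 'name': None, 'interval': 'step', 'frequency': 100,
#     #     'reduce_on_plateau': False, 'monitor': None, 'strict': True}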