# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Learning Rate Monitor
=====================

Monitors and logs the learning rate of lr schedulers during training.

"""
from collections import Counter, defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Set, Type

from torch.optim.optimizer import Optimizer

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LearningRateMonitor(Callback):
    r"""
    Automatically monitors and logs the learning rate of learning rate schedulers during training.

    Args:
        logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers
            at the same interval, set to ``None`` to log at individual interval
            according to the ``interval`` key of each scheduler. Defaults to ``None``.
        log_momentum: option to also log the momentum values of the optimizer, if the optimizer
            has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.

    Raises:
        MisconfigurationException:
            If ``logging_interval`` is none of ``"step"``, ``"epoch"``, or ``None``.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import LearningRateMonitor
        >>> lr_monitor = LearningRateMonitor(logging_interval='step')
        >>> trainer = Trainer(callbacks=[lr_monitor])

    Logging names are automatically determined based on the optimizer class name.
    In case of multiple optimizers of the same type, they will be named ``lr-Adam``,
    ``lr-Adam-1`` etc. If an optimizer has multiple parameter groups they will
    be named ``lr-Adam/pg1``, ``lr-Adam/pg2`` etc. To control naming, pass in a
    ``name`` keyword in the construction of the learning rate schedulers.
    A ``name`` keyword can also be used for parameter groups in the
    construction of the optimizer.

    Example::

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(...)
            lr_scheduler = {
                'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
                'name': 'my_logging_name'
            }
            return [optimizer], [lr_scheduler]

    Example::

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(
                [{
                    'params': [p for p in self.parameters()],
                    'name': 'my_parameter_group_name'
                }],
                lr=0.1
            )
            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
            return [optimizer], [lr_scheduler]

    """

    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
        if logging_interval not in (None, "step", "epoch"):
            raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.")

        self.logging_interval = logging_interval
        self.log_momentum = log_momentum
        # name -> list of every learning rate logged under that name; built in ``on_train_start``.
        self.lrs = None
        # One base name per scheduler (no param-group suffix), in ``trainer.lr_schedulers`` order.
        self.lr_sch_names = []

    def on_train_start(self, trainer, *args, **kwargs):
        """
        Called before training, determines unique names for all lr
        schedulers in the case of multiple of the same type or in
        the case of multiple parameter groups.

        Raises:
            MisconfigurationException:
                If ``Trainer`` has no ``logger``.
        """
        if not trainer.logger:
            raise MisconfigurationException(
                "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
            )

        if not trainer.lr_schedulers:
            rank_zero_warn(
                "You are using `LearningRateMonitor` callback with models that"
                " have no learning rate schedulers. Please see documentation"
                " for `configure_optimizers` method.",
                RuntimeWarning,
            )

        if self.log_momentum:

            def _has_no_momentum(sch: Dict[str, Any]) -> bool:
                # An optimizer lacks momentum only if it exposes neither key; checking
                # each key across *all* optimizers separately would wrongly warn for
                # e.g. an SGD (momentum) + Adam (betas) combination.
                defaults = sch["scheduler"].optimizer.defaults
                return "momentum" not in defaults and "betas" not in defaults

            if any(_has_no_momentum(sch) for sch in trainer.lr_schedulers):
                rank_zero_warn(
                    "You have set log_momentum=True, but some optimizers do not"
                    " have momentum. This will log a value 0 for the momentum.",
                    RuntimeWarning,
                )

        # Reset before re-discovering names so repeated fits do not accumulate stale
        # entries; ``_extract_stats`` zips these names with ``trainer.lr_schedulers``.
        self.lr_sch_names = []
        names = self._find_names(trainer.lr_schedulers)

        # Initialize for storing values
        self.lrs = {name: [] for name in names}
        self.last_momentum_values = {name + "-momentum": None for name in names}

    def on_train_batch_start(self, trainer, *args, **kwargs):
        """Log step-level stats on batches selected by ``_should_log``, unless logging per epoch."""
        if not self._should_log(trainer):
            return

        if self.logging_interval != "epoch":
            # ``None`` -> honour each scheduler's own interval; ``'step'`` -> log every scheduler.
            interval = "step" if self.logging_interval is None else "any"
            latest_stat = self._extract_stats(trainer, interval)

            if latest_stat:
                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

    def on_train_epoch_start(self, trainer, *args, **kwargs):
        """Log epoch-level stats at the start of each epoch, unless logging per step."""
        if self.logging_interval != "step":
            # ``None`` -> honour each scheduler's own interval; ``'epoch'`` -> log every scheduler.
            interval = "epoch" if self.logging_interval is None else "any"
            latest_stat = self._extract_stats(trainer, interval)

            if latest_stat:
                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

    def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
        """Collect current lr (and optionally momentum) for schedulers matching ``interval``.

        ``interval == "any"`` selects every scheduler regardless of its own ``interval`` key.
        """
        latest_stat = {}

        # Param groups may have been added since training started; re-derive the names
        # and migrate stored histories to the new keys.
        names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
        self._remap_keys(names)

        for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
            if scheduler["interval"] == interval or interval == "any":
                opt = scheduler["scheduler"].optimizer
                param_groups = opt.param_groups
                use_betas = "betas" in opt.defaults

                for i, pg in enumerate(param_groups):
                    name_and_suffix = self._add_suffix(name, param_groups, i)
                    lr = self._extract_lr(pg, name_and_suffix)
                    latest_stat.update(lr)
                    # ``name_and_suffix`` always starts with ``name``; only rewrite that
                    # prefix so a param-group name containing ``name`` is left intact.
                    momentum = self._extract_momentum(
                        param_group=pg,
                        name=name_and_suffix.replace(name, f"{name}-momentum", 1),
                        use_betas=use_betas,
                    )
                    latest_stat.update(momentum)

        return latest_stat

    def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
        """Record the param group's ``lr`` in ``self.lrs[name]`` and return it as ``{name: lr}``."""
        lr = param_group.get("lr")
        self.lrs[name].append(lr)
        return {name: lr}

    def _remap_keys(self, names: List[str], token: str = "/pg1") -> None:
        """
        This function is used to remap the keys if param groups for a given optimizer increased.
        """
        for new_name in names:
            old_name = new_name.replace(token, "")
            if token in new_name and old_name in self.lrs:
                # A single-group optimizer gained groups: move its history to ``<name>/pg1``.
                self.lrs[new_name] = self.lrs.pop(old_name)
            elif new_name not in self.lrs:
                self.lrs[new_name] = []

    def _extract_momentum(self, param_group: Dict[str, Any], name: str, use_betas: bool) -> Dict[str, float]:
        """Return ``{name: momentum}`` (empty when momentum logging is disabled).

        Uses ``betas[0]`` for beta-based optimizers, else the ``momentum`` entry (0 if absent).
        """
        if not self.log_momentum:
            return {}

        momentum = param_group.get("betas")[0] if use_betas else param_group.get("momentum", 0)
        self.last_momentum_values[name] = momentum
        return {name: momentum}

    def _add_prefix(
        self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int]
    ) -> str:
        """Append ``-<k>`` to ``name`` when this is the (k+1)-th unnamed optimizer of its class."""
        if optimizer_cls not in seen_optimizer_types:
            return name
        count = seen_optimizer_types[optimizer_cls]
        return name + f"-{count - 1}" if count > 1 else name

    def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str:
        """Append a ``/<group>`` suffix identifying the param group when needed.

        With several groups the suffix is the group's ``name`` (or ``pg<index+1>``); with a
        single group it is added only if that group carries an explicit ``name``.
        """
        if len(param_groups) > 1:
            if not use_names:
                return f"{name}/pg{param_group_index+1}"
            pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index+1}")
            return f"{name}/{pg_name}"
        elif use_names:
            pg_name = param_groups[param_group_index].get("name")
            return f"{name}/{pg_name}" if pg_name else name
        return name

    def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
        """Return the param-group names (explicit or default ``pg<i>``) that occur more than once."""
        names = [pg.get("name", f"pg{i}") for i, pg in enumerate(param_groups, start=1)]
        return {n for n, count in Counter(names).items() if count > 1}

    def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
        """Build a unique logging name for every param group of every scheduler's optimizer.

        Args:
            lr_schedulers: the trainer's scheduler config dicts (read keys: ``scheduler``, ``name``).
            add_lr_sch_names: when ``True``, also append each scheduler's base name to
                ``self.lr_sch_names``.

        Raises:
            MisconfigurationException: if one optimizer has duplicated param-group names.
        """
        # Create unique names in the case we have multiple of the same learning
        # rate scheduler + multiple parameter groups
        names = []
        seen_optimizer_types = defaultdict(int)
        for scheduler in lr_schedulers:
            sch = scheduler["scheduler"]
            if scheduler["name"] is not None:
                name = scheduler["name"]
            else:
                name = "lr-" + sch.optimizer.__class__.__name__

            optimizer_cls = type(sch.optimizer)
            # Only auto-named schedulers count toward the ``-<k>`` disambiguation suffix.
            if scheduler["name"] is None:
                seen_optimizer_types[optimizer_cls] += 1

            # Multiple param groups for the same scheduler
            param_groups = sch.optimizer.param_groups

            duplicates = self._duplicate_param_group_names(param_groups)
            if duplicates:
                raise MisconfigurationException(
                    "A single `Optimizer` cannot have multiple parameter groups with identical "
                    f"`name` values. {name} has duplicated parameter group names {duplicates}"
                )

            name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)

            names.extend(self._add_suffix(name, param_groups, i) for i in range(len(param_groups)))

            if add_lr_sch_names:
                self.lr_sch_names.append(name)

        return names

    @staticmethod
    def _should_log(trainer) -> bool:
        """True on every ``log_every_n_steps``-th step, or when the trainer is about to stop."""
        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop