From 142bc0230e228cd2e851481e5a07069e7d198655 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Thu, 30 Apr 2020 14:06:41 +0200
Subject: [PATCH] Learning rate log callback (#1498)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* base implementation

* docs + implementation

* fix styling

* add lr string

* renaming

* CHANGELOG.md

* add tests

* Apply suggestions from code review

Co-Authored-By: Adrian Wälchli

* Apply suggestions from code review

* Update pytorch_lightning/callbacks/lr_logger.py

* Update pytorch_lightning/callbacks/lr_logger.py

* add test for naming

* base implementation

* docs + implementation

* fix styling

* add lr string

* renaming

* CHANGELOG.md

* add tests

* Apply suggestions from code review

Co-Authored-By: Adrian Wälchli

* Apply suggestions from code review

* Update pytorch_lightning/callbacks/lr_logger.py

* Update pytorch_lightning/callbacks/lr_logger.py

* add test for naming

* Update pytorch_lightning/callbacks/lr_logger.py

Co-Authored-By: Adrian Wälchli

* suggestions from code review

* fix styling

* rebase

* fix tests

Co-authored-by: Nicki Skafte
Co-authored-by: Jirka Borovec
Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                             |   2 +
 docs/source/callbacks.rst                |   8 ++
 pytorch_lightning/callbacks/__init__.py  |   2 +
 pytorch_lightning/callbacks/lr_logger.py | 118 +++++++++++++++++++++++
 tests/callbacks/test_callbacks.py        |  59 +++++++++++-
 5 files changed, 188 insertions(+), 1 deletion(-)
 create mode 100755 pytorch_lightning/callbacks/lr_logger.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10ec061f18..f67e85a452 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498))
+
 ### Changed

 ### Deprecated
diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst
index 10323472fa..a2969820b2 100644
--- a/docs/source/callbacks.rst
+++ b/docs/source/callbacks.rst
@@ -84,3 +84,11 @@ We successfully extended functionality without polluting our super clean
 .. automodule:: pytorch_lightning.callbacks.progress
    :noindex:
    :exclude-members:
+
+---------
+
+.. automodule:: pytorch_lightning.callbacks.lr_logger
+   :noindex:
+   :exclude-members:
+       _extract_lr,
+       _find_names
\ No newline at end of file
diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py
index c232060ca4..7e8e0ce5bc 100644
--- a/pytorch_lightning/callbacks/__init__.py
+++ b/pytorch_lightning/callbacks/__init__.py
@@ -2,6 +2,7 @@ from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks.lr_logger import LearningRateLogger
 from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar

 __all__ = [
@@ -9,6 +10,7 @@ __all__ = [
     'EarlyStopping',
     'ModelCheckpoint',
     'GradientAccumulationScheduler',
+    'LearningRateLogger',
     'ProgressBarBase',
     'ProgressBar',
 ]
diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py
new file mode 100755
index 0000000000..6ad68905bc
--- /dev/null
+++ b/pytorch_lightning/callbacks/lr_logger.py
@@ -0,0 +1,118 @@
+r"""
+
+Logging of learning rates
+=========================
+
+Log learning rates for lr schedulers during training
+
+"""
+
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+class LearningRateLogger(Callback):
+    r"""
+    Automatically logs the learning rate of learning rate schedulers during training.
+
+    Example::
+
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import LearningRateLogger
+        >>> lr_logger = LearningRateLogger()
+        >>> trainer = Trainer(callbacks=[lr_logger])
+
+    Logging names are automatically determined based on the optimizer class name.
+    In case of multiple optimizers of the same type, they will be named `lr-Adam`,
+    `lr-Adam-1` etc. If an optimizer has multiple parameter groups, they will
+    be named `lr-Adam/pg1`, `lr-Adam/pg2` etc. To control naming, pass a
+    `name` keyword in the construction of the learning rate schedulers.
+
+    Example::
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(...)
+            lr_scheduler = {'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
+                            'name': 'my_logging_name'}
+            return [optimizer], [lr_scheduler]
+    """
+    def __init__(self):
+        self.lrs = None
+        self.lr_sch_names = []
+
+    def on_train_start(self, trainer, pl_module):
+        """ Called before training, determines unique names for all lr
+            schedulers in the case of multiple of the same type or in
+            the case of multiple parameter groups
+        """
+        if not trainer.lr_schedulers:
+            raise MisconfigurationException(
+                'Cannot use LearningRateLogger callback with models that have no'
+                ' learning rate schedulers. Please see documentation for'
+                ' `configure_optimizers` method.')
+
+        if not trainer.logger:
+            raise MisconfigurationException(
+                'Cannot use LearningRateLogger callback with Trainer that has no logger.')
+
+        # Find names for schedulers
+        names = self._find_names(trainer.lr_schedulers)
+
+        # Initialize for storing values (a separate list per scheduler name)
+        self.lrs = {name: [] for name in names}
+
+    def on_batch_start(self, trainer, pl_module):
+        latest_stat = self._extract_lr(trainer, 'step')
+        if trainer.logger and latest_stat:
+            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+
+    def on_epoch_start(self, trainer, pl_module):
+        latest_stat = self._extract_lr(trainer, 'epoch')
+        if trainer.logger and latest_stat:
+            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+
+    def _extract_lr(self, trainer, interval):
+        """ Extracts learning rates for lr schedulers and saves information
+            into dict structure. """
+        latest_stat = {}
+        for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
+            if scheduler['interval'] == interval:
+                param_groups = scheduler['scheduler'].optimizer.param_groups
+                if len(param_groups) != 1:
+                    for i, pg in enumerate(param_groups):
+                        lr, key = pg['lr'], f'{name}/pg{i + 1}'
+                        self.lrs[key].append(lr)
+                        latest_stat[key] = lr
+                else:
+                    self.lrs[name].append(param_groups[0]['lr'])
+                    latest_stat[name] = param_groups[0]['lr']
+        return latest_stat
+
+    def _find_names(self, lr_schedulers):
+        # Create unique names in the case we have multiple of the same learning
+        # rate scheduler + multiple parameter groups
+        names = []
+        for scheduler in lr_schedulers:
+            sch = scheduler['scheduler']
+            if 'name' in scheduler:
+                name = scheduler['name']
+            else:
+                opt_name = 'lr-' + sch.optimizer.__class__.__name__
+                i, name = 1, opt_name
+                # Multiple schedulers of the same type
+                while True:
+                    if name not in names:
+                        break
+                    i, name = i + 1, f'{opt_name}-{i}'
+
+            # Multiple param groups for the same scheduler
+            param_groups = sch.optimizer.param_groups
+            if len(param_groups) != 1:
+                for i, pg in enumerate(param_groups):
+                    temp = name + '/pg' + str(i + 1)
+                    names.append(temp)
+            else:
+                names.append(name)
+
+            self.lr_sch_names.append(name)
+        return names
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 9dba21eab0..a082c5ec6f 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -2,11 +2,12 @@ import pytest

 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
 from tests.base import (
     LightTrainDataloader,
     LightTestMixin,
     LightValidationMixin,
+    LightTestOptimizersWithMixedSchedulingMixin,
     TestModelBase
 )
@@ -273,3 +274,59 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):

     # These should be different if the dirpath has be overridden
     assert trainer.ckpt_path != trainer.default_root_dir
+
+
+def test_lr_logger_single_lr(tmpdir):
+    """ Test that learning rates are extracted and logged for a single lr scheduler """
+    tutils.reset_seed()
+
+    class CurrentTestModel(LightTrainDataloader, TestModelBase):
+        pass
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    results = trainer.fit(model)
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+
+
+def test_lr_logger_multi_lrs(tmpdir):
+    """ Test that learning rates are extracted and logged for multiple lr schedulers """
+    tutils.reset_seed()
+
+    class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin,
+                           LightTrainDataloader,
+                           TestModelBase):
+        pass
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    results = trainer.fit(model)
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
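
A minimal end-to-end sketch of the callback introduced by this patch, assuming the Trainer and `configure_optimizers` API used elsewhere in the diff; the `DemoModel` module, its layer sizes, the optimizer settings, and the scheduler name `my_logging_name` are illustrative only, not part of the patch::

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateLogger


    class DemoModel(pl.LightningModule):
        """Tiny module used only to demonstrate LearningRateLogger naming."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = nn.functional.cross_entropy(self(x), y)
            return {'loss': loss}

        def train_dataloader(self):
            # random data, enough to drive a couple of epochs
            x = torch.randn(64, 32)
            y = torch.randint(0, 2, (64,))
            return DataLoader(TensorDataset(x, y), batch_size=16)

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            # the 'name' key overrides the automatic 'lr-Adam' logging name
            scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
                         'interval': 'epoch',
                         'name': 'my_logging_name'}
            return [optimizer], [scheduler]


    if __name__ == '__main__':
        lr_logger = LearningRateLogger()
        trainer = Trainer(max_epochs=2, callbacks=[lr_logger])
        trainer.fit(DemoModel())
        # learning rates collected by the callback, keyed by the chosen name
        print(lr_logger.lrs['my_logging_name'])

With the explicit 'name' key the metric is logged as `my_logging_name`; dropping the key falls back to the automatic `lr-Adam` / `lr-Adam-1` naming exercised by the tests above.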