Learning rate log callback (#1498)

* base implementation

* docs + implementation

* fix styling

* add lr string

* renaming

* CHANGELOG.md

* add tests

* Apply suggestions from code review

Co-Authored-By: Adrian Wälchli <aedu.waelchli@gmail.com>

* Apply suggestions from code review

* Update pytorch_lightning/callbacks/lr_logger.py

* Update pytorch_lightning/callbacks/lr_logger.py

* add test for naming

* Update pytorch_lightning/callbacks/lr_logger.py

Co-Authored-By: Adrian Wälchli <aedu.waelchli@gmail.com>

* suggestions from code review

* fix styling

* rebase

* fix tests

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Nicki Skafte 2020-04-30 14:06:41 +02:00 committed by GitHub
parent 3eac6cfd4f
commit 142bc0230e
5 changed files with 188 additions and 1 deletion

CHANGELOG.md

@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498))

### Changed

### Deprecated

docs/source/callbacks.rst

@@ -84,3 +84,11 @@ We successfully extended functionality without polluting our super clean

.. automodule:: pytorch_lightning.callbacks.progress
   :noindex:
   :exclude-members:

---------

.. automodule:: pytorch_lightning.callbacks.lr_logger
   :noindex:
   :exclude-members:
      _extract_lr,
      _find_names

pytorch_lightning/callbacks/__init__.py

@@ -2,6 +2,7 @@ from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.lr_logger import LearningRateLogger
from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar

__all__ = [
@@ -9,6 +10,7 @@ __all__ = [
    'EarlyStopping',
    'ModelCheckpoint',
    'GradientAccumulationScheduler',
    'LearningRateLogger',
    'ProgressBarBase',
    'ProgressBar',
]

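With the export above, the callback is importable alongside the other built-in callbacks. A minimal usage sketch (the Trainer arguments here are illustrative, not part of this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateLogger

    # any LightningModule whose configure_optimizers returns at least one
    # lr scheduler will have its learning rates logged each step/epoch
    lr_logger = LearningRateLogger()
    trainer = Trainer(max_epochs=5, callbacks=[lr_logger])  # illustrative settings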
pytorch_lightning/callbacks/lr_logger.py

@@ -0,0 +1,118 @@
r"""
Logging of learning rates
=========================

Log learning rate for lr schedulers during training

"""

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LearningRateLogger(Callback):
    r"""
    Automatically logs learning rate for learning rate schedulers during training.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import LearningRateLogger
        >>> lr_logger = LearningRateLogger()
        >>> trainer = Trainer(callbacks=[lr_logger])

    Logging names are automatically determined based on optimizer class name.
    In case of multiple optimizers of the same type, they will be named
    `lr-Adam`, `lr-Adam-1` etc. If an optimizer has multiple parameter groups
    they will be named `lr-Adam/pg1`, `lr-Adam/pg2` etc. To control naming,
    pass in a `name` keyword in the construction of the learning rate schedulers

    Example::

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(...)
            lr_scheduler = {'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
                            'name': 'my_logging_name'}
            return [optimizer], [lr_scheduler]

    """
    def __init__(self):
        self.lrs = None
        self.lr_sch_names = []

    def on_train_start(self, trainer, pl_module):
        """ Called before training, determines unique names for all lr
            schedulers in the case of multiple of the same type or in
            the case of multiple parameter groups
        """
        if not trainer.lr_schedulers:
            raise MisconfigurationException(
                'Cannot use LearningRateLogger callback with models that have no'
                ' learning rate schedulers. Please see documentation for'
                ' `configure_optimizers` method.')

        if not trainer.logger:
            raise MisconfigurationException(
                'Cannot use LearningRateLogger callback with Trainer that has no logger.')

        # Find names for schedulers
        names = self._find_names(trainer.lr_schedulers)

        # Initialize for storing values; a dict comprehension gives each key
        # its own list (dict.fromkeys(names, []) would share a single list)
        self.lrs = {name: [] for name in names}

    def on_batch_start(self, trainer, pl_module):
        latest_stat = self._extract_lr(trainer, 'step')
        if trainer.logger and latest_stat:
            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

    def on_epoch_start(self, trainer, pl_module):
        latest_stat = self._extract_lr(trainer, 'epoch')
        if trainer.logger and latest_stat:
            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

    def _extract_lr(self, trainer, interval):
        """ Extracts learning rates for lr schedulers and saves information
            into dict structure. """
        latest_stat = {}
        for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
            if scheduler['interval'] == interval:
                param_groups = scheduler['scheduler'].optimizer.param_groups
                if len(param_groups) != 1:
                    for i, pg in enumerate(param_groups):
                        # key must match the '/pgX' names created in _find_names
                        lr, key = pg['lr'], f'{name}/pg{i + 1}'
                        self.lrs[key].append(lr)
                        latest_stat[key] = lr
                else:
                    self.lrs[name].append(param_groups[0]['lr'])
                    latest_stat[name] = param_groups[0]['lr']
        return latest_stat

    def _find_names(self, lr_schedulers):
        # Create unique names in the case we have multiple of the same learning
        # rate scheduler + multiple parameter groups
        names = []
        for scheduler in lr_schedulers:
            sch = scheduler['scheduler']
            if 'name' in scheduler:
                name = scheduler['name']
            else:
                opt_name = 'lr-' + sch.optimizer.__class__.__name__
                # Multiple schedulers of the same type get an index suffix
                i, name = 1, opt_name
                while name in names:
                    name = f'{opt_name}-{i}'
                    i += 1

            # Multiple param groups for the same scheduler
            param_groups = sch.optimizer.param_groups
            if len(param_groups) != 1:
                for i, pg in enumerate(param_groups):
                    names.append(f'{name}/pg{i + 1}')
            else:
                names.append(name)

            self.lr_sch_names.append(name)

        return names

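For reference, a short sketch of how the naming rules above play out in `configure_optimizers`; the optimizers, schedulers and submodule names here are illustrative, not part of this commit:

    import torch

    def configure_optimizers(self):
        # two optimizers of the same class are auto-named 'lr-Adam' and 'lr-Adam-1'
        opt_a = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)  # hypothetical submodules
        opt_b = torch.optim.Adam(self.decoder.parameters(), lr=1e-4)
        sch_a = torch.optim.lr_scheduler.StepLR(opt_a, step_size=10)
        # an explicit 'name' key overrides the automatic 'lr-Adam-1' label
        sch_b = {'scheduler': torch.optim.lr_scheduler.StepLR(opt_b, step_size=10),
                 'name': 'my_logging_name'}
        return [opt_a, opt_b], [sch_a, sch_b]

If an optimizer carries several parameter groups, each group is logged separately under `lr-Adam/pg1`, `lr-Adam/pg2`, and so on.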
tests/trainer/test_callbacks.py

@@ -2,11 +2,12 @@ import pytest
import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
from tests.base import (
    LightTrainDataloader,
    LightTestMixin,
    LightValidationMixin,
    LightTestOptimizersWithMixedSchedulingMixin,
    TestModelBase
)
@@ -273,3 +274,59 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    # These should be different if the dirpath has been overridden
    assert trainer.ckpt_path != trainer.default_root_dir


def test_lr_logger_single_lr(tmpdir):
    """ Test that learning rates are extracted and logged for single lr scheduler """
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=5,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )
    results = trainer.fit(model)

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
        'Names of learning rates not set correctly'


def test_lr_logger_multi_lrs(tmpdir):
    """ Test that learning rates are extracted and logged for multi lr schedulers """
    tutils.reset_seed()

    class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin,
                           LightTrainDataloader,
                           TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )
    results = trainer.fit(model)

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
        'Names of learning rates not set correctly'
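`LightTestOptimizersWithMixedSchedulingMixin` is defined in `tests.base` and is not shown in this diff; for the asserted keys `lr-Adam` and `lr-Adam-1` to appear, it presumably configures two Adam optimizers with schedulers on different intervals, roughly along these lines (a hypothetical sketch, not the actual mixin):

    import torch

    class LightTestOptimizersWithMixedSchedulingMixin:
        def configure_optimizers(self):
            # two Adam optimizers -> auto-named 'lr-Adam' and 'lr-Adam-1'
            opt1 = torch.optim.Adam(self.parameters(), lr=1e-3)
            opt2 = torch.optim.Adam(self.parameters(), lr=1e-3)
            # mixed scheduling: one scheduler stepped every batch, one every epoch
            sch1 = {'scheduler': torch.optim.lr_scheduler.LambdaLR(opt1, lambda step: 0.95 ** step),
                    'interval': 'step'}
            sch2 = torch.optim.lr_scheduler.StepLR(opt2, step_size=1)
            return [opt1, opt2], [sch1, sch2]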