Learning rate log callback (#1498)
* base implementation * docs + implementation * fix styling * add lr string * renaming * CHANGELOG.md * add tests * Apply suggestions from code review Co-Authored-By: Adrian Wälchli <aedu.waelchli@gmail.com> * Apply suggestions from code review * Update pytorch_lightning/callbacks/lr_logger.py * Update pytorch_lightning/callbacks/lr_logger.py * add test for naming * base implementation * docs + implementation * fix styling * add lr string * renaming * CHANGELOG.md * add tests * Apply suggestions from code review Co-Authored-By: Adrian Wälchli <aedu.waelchli@gmail.com> * Apply suggestions from code review * Update pytorch_lightning/callbacks/lr_logger.py * Update pytorch_lightning/callbacks/lr_logger.py * add test for naming * Update pytorch_lightning/callbacks/lr_logger.py Co-Authored-By: Adrian Wälchli <aedu.waelchli@gmail.com> * suggestions from code review * fix styling * rebase * fix tests Co-authored-by: Nicki Skafte <nugginea@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
3eac6cfd4f
commit
142bc0230e
|
@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
- Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498))
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
|
@ -84,3 +84,11 @@ We successfully extended functionality without polluting our super clean
|
||||||
.. automodule:: pytorch_lightning.callbacks.progress
|
.. automodule:: pytorch_lightning.callbacks.progress
|
||||||
:noindex:
|
:noindex:
|
||||||
:exclude-members:
|
:exclude-members:
|
||||||
|
|
||||||
|
---------
|
||||||
|
|
||||||
|
.. automodule:: pytorch_lightning.callbacks.lr_logger
|
||||||
|
:noindex:
|
||||||
|
:exclude-members:
|
||||||
|
_extract_lr,
|
||||||
|
_find_names
|
|
@ -2,6 +2,7 @@ from pytorch_lightning.callbacks.base import Callback
|
||||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||||
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
|
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
|
||||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||||
|
from pytorch_lightning.callbacks.lr_logger import LearningRateLogger
|
||||||
from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar
|
from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -9,6 +10,7 @@ __all__ = [
|
||||||
'EarlyStopping',
|
'EarlyStopping',
|
||||||
'ModelCheckpoint',
|
'ModelCheckpoint',
|
||||||
'GradientAccumulationScheduler',
|
'GradientAccumulationScheduler',
|
||||||
|
'LearningRateLogger',
|
||||||
'ProgressBarBase',
|
'ProgressBarBase',
|
||||||
'ProgressBar',
|
'ProgressBar',
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
r"""
|
||||||
|
|
||||||
|
Logging of learning rates
|
||||||
|
=========================
|
||||||
|
|
||||||
|
Log learning rate for lr schedulers during training
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pytorch_lightning.callbacks.base import Callback
|
||||||
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
|
|
||||||
|
|
||||||
|
class LearningRateLogger(Callback):
|
||||||
|
r"""
|
||||||
|
Automatically logs learning rate for learning rate schedulers during training.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from pytorch_lightning import Trainer
|
||||||
|
>>> from pytorch_lightning.callbacks import LearningRateLogger
|
||||||
|
>>> lr_logger = LearningRateLogger()
|
||||||
|
>>> trainer = Trainer(callbacks=[lr_logger])
|
||||||
|
|
||||||
|
Logging names are automatically determined based on optimizer class name.
|
||||||
|
In case of multiple optimizers of same type, they will be named `Adam`,
|
||||||
|
`Adam-1` etc. If a optimizer has multiple parameter groups they will
|
||||||
|
be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a
|
||||||
|
`name` keyword in the construction of the learning rate schdulers
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
def configure_optimizer(self):
|
||||||
|
optimizer = torch.optim.Adam(...)
|
||||||
|
lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...)
|
||||||
|
'name': 'my_logging_name'}
|
||||||
|
return [optimizer], [lr_scheduler]
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.lrs = None
|
||||||
|
self.lr_sch_names = []
|
||||||
|
|
||||||
|
def on_train_start(self, trainer, pl_module):
|
||||||
|
""" Called before training, determines unique names for all lr
|
||||||
|
schedulers in the case of multiple of the same type or in
|
||||||
|
the case of multiple parameter groups
|
||||||
|
"""
|
||||||
|
if trainer.lr_schedulers == []:
|
||||||
|
raise MisconfigurationException(
|
||||||
|
'Cannot use LearningRateLogger callback with models that have no'
|
||||||
|
' learning rate schedulers. Please see documentation for'
|
||||||
|
' `configure_optimizers` method.')
|
||||||
|
|
||||||
|
if not trainer.logger:
|
||||||
|
raise MisconfigurationException(
|
||||||
|
'Cannot use LearningRateLogger callback with Trainer that has no logger.')
|
||||||
|
|
||||||
|
# Find names for schedulers
|
||||||
|
names = self._find_names(trainer.lr_schedulers)
|
||||||
|
|
||||||
|
# Initialize for storing values
|
||||||
|
self.lrs = dict.fromkeys(names, [])
|
||||||
|
|
||||||
|
def on_batch_start(self, trainer, pl_module):
|
||||||
|
latest_stat = self._extract_lr(trainer, 'step')
|
||||||
|
if trainer.logger and latest_stat:
|
||||||
|
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
|
||||||
|
|
||||||
|
def on_epoch_start(self, trainer, pl_module):
|
||||||
|
latest_stat = self._extract_lr(trainer, 'epoch')
|
||||||
|
if trainer.logger and latest_stat:
|
||||||
|
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
|
||||||
|
|
||||||
|
def _extract_lr(self, trainer, interval):
|
||||||
|
""" Extracts learning rates for lr schedulers and saves information
|
||||||
|
into dict structure. """
|
||||||
|
latest_stat = {}
|
||||||
|
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
|
||||||
|
if scheduler['interval'] == interval:
|
||||||
|
param_groups = scheduler['scheduler'].optimizer.param_groups
|
||||||
|
if len(param_groups) != 1:
|
||||||
|
for i, pg in enumerate(param_groups):
|
||||||
|
lr, key = pg['lr'], f'{name}/{i + 1}'
|
||||||
|
self.lrs[key].append(lr)
|
||||||
|
latest_stat[key] = lr
|
||||||
|
else:
|
||||||
|
self.lrs[name].append(param_groups[0]['lr'])
|
||||||
|
latest_stat[name] = param_groups[0]['lr']
|
||||||
|
return latest_stat
|
||||||
|
|
||||||
|
def _find_names(self, lr_schedulers):
|
||||||
|
# Create uniqe names in the case we have multiple of the same learning
|
||||||
|
# rate schduler + multiple parameter groups
|
||||||
|
names = []
|
||||||
|
for scheduler in lr_schedulers:
|
||||||
|
sch = scheduler['scheduler']
|
||||||
|
if 'name' in scheduler:
|
||||||
|
name = scheduler['name']
|
||||||
|
else:
|
||||||
|
opt_name = 'lr-' + sch.optimizer.__class__.__name__
|
||||||
|
i, name = 1, opt_name
|
||||||
|
# Multiple schduler of the same type
|
||||||
|
while True:
|
||||||
|
if name not in names:
|
||||||
|
break
|
||||||
|
i, name = i + 1, f'{opt_name}-{i}'
|
||||||
|
|
||||||
|
# Multiple param groups for the same schduler
|
||||||
|
param_groups = sch.optimizer.param_groups
|
||||||
|
if len(param_groups) != 1:
|
||||||
|
for i, pg in enumerate(param_groups):
|
||||||
|
temp = name + '/pg' + str(i + 1)
|
||||||
|
names.append(temp)
|
||||||
|
else:
|
||||||
|
names.append(name)
|
||||||
|
|
||||||
|
self.lr_sch_names.append(name)
|
||||||
|
return names
|
|
@ -2,11 +2,12 @@ import pytest
|
||||||
import tests.base.utils as tutils
|
import tests.base.utils as tutils
|
||||||
from pytorch_lightning import Callback
|
from pytorch_lightning import Callback
|
||||||
from pytorch_lightning import Trainer, LightningModule
|
from pytorch_lightning import Trainer, LightningModule
|
||||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
|
||||||
from tests.base import (
|
from tests.base import (
|
||||||
LightTrainDataloader,
|
LightTrainDataloader,
|
||||||
LightTestMixin,
|
LightTestMixin,
|
||||||
LightValidationMixin,
|
LightValidationMixin,
|
||||||
|
LightTestOptimizersWithMixedSchedulingMixin,
|
||||||
TestModelBase
|
TestModelBase
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -273,3 +274,59 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
|
||||||
|
|
||||||
# These should be different if the dirpath has be overridden
|
# These should be different if the dirpath has be overridden
|
||||||
assert trainer.ckpt_path != trainer.default_root_dir
|
assert trainer.ckpt_path != trainer.default_root_dir
|
||||||
|
|
||||||
|
|
||||||
|
def test_lr_logger_single_lr(tmpdir):
|
||||||
|
""" Test that learning rates are extracted and logged for single lr scheduler"""
|
||||||
|
tutils.reset_seed()
|
||||||
|
|
||||||
|
class CurrentTestModel(LightTrainDataloader, TestModelBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
hparams = tutils.get_default_hparams()
|
||||||
|
model = CurrentTestModel(hparams)
|
||||||
|
|
||||||
|
lr_logger = LearningRateLogger()
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
max_epochs=5,
|
||||||
|
val_percent_check=0.1,
|
||||||
|
train_percent_check=0.5,
|
||||||
|
callbacks=[lr_logger]
|
||||||
|
)
|
||||||
|
results = trainer.fit(model)
|
||||||
|
|
||||||
|
assert lr_logger.lrs, 'No learning rates logged'
|
||||||
|
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
|
||||||
|
'Number of learning rates logged does not match number of lr schedulers'
|
||||||
|
assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
|
||||||
|
'Names of learning rates not set correctly'
|
||||||
|
|
||||||
|
|
||||||
|
def test_lr_logger_multi_lrs(tmpdir):
|
||||||
|
""" Test that learning rates are extracted and logged for multi lr schedulers """
|
||||||
|
tutils.reset_seed()
|
||||||
|
|
||||||
|
class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin,
|
||||||
|
LightTrainDataloader,
|
||||||
|
TestModelBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
hparams = tutils.get_default_hparams()
|
||||||
|
model = CurrentTestModel(hparams)
|
||||||
|
|
||||||
|
lr_logger = LearningRateLogger()
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
max_epochs=1,
|
||||||
|
val_percent_check=0.1,
|
||||||
|
train_percent_check=0.5,
|
||||||
|
callbacks=[lr_logger]
|
||||||
|
)
|
||||||
|
results = trainer.fit(model)
|
||||||
|
|
||||||
|
assert lr_logger.lrs, 'No learning rates logged'
|
||||||
|
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
|
||||||
|
'Number of learning rates logged does not match number of lr schedulers'
|
||||||
|
assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
|
||||||
|
'Names of learning rates not set correctly'
|
||||||
|
|
Loading…
Reference in New Issue