Update lr_logger.py (#2847)

* Update lr_logger.py

when logging the learning rate, we should provide different logging intervals to choose from, including 'step' and 'epoch'

* Update lr_logger.py

add some type annotations and docstrings

* Update lr_logger.py

fix a bug where `on_train_batch_start()` is never triggered; use `on_batch_start()` instead. Also add an `interval` argument so that learning rates can be recorded with respect to `global_step` or `current_epoch`.

* Update lr_logger.py

restore _extract_lr()

* suggestion

* Update lr_logger.py

modify `_extract_lr()`; the `interval` parameter no longer needs to be passed explicitly.

* Update test_lr_logger.py

SkafteNicki's suggestion

* log_interval now supports `None`, `step`, `epoch`

* change `log_interval` to `logging_interval`

* Update test_lr_logger.py

* Update lr_logger.py

* put type checks into `on_train_start()`

* cleanup

* docstring typos

* minor changes from suggestions

Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Authored by Caldera on 2020-08-10 00:30:43 +08:00, committed by GitHub
parent 0cfa05b703
commit 6c18fd9a24
2 changed files with 46 additions and 15 deletions
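
As a hedged usage sketch of the new option, expanded from the docstring example in the diff below (`MyModel` is a placeholder for any LightningModule and is not part of the commit):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateLogger

# Log the learning rate of every optimizer once per training step.
lr_logger = LearningRateLogger(logging_interval='step')

# With logging_interval=None (the default), each scheduler is instead logged
# at its own 'interval' key ('step' or 'epoch') from configure_optimizers().
trainer = Trainer(max_epochs=2, callbacks=[lr_logger])
trainer.fit(MyModel())  # MyModel: hypothetical LightningModule subclass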

lr_logger.py

@@ -7,6 +7,8 @@ Log learning rate for lr schedulers during training
"""
from typing import Optional
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -16,11 +18,16 @@ class LearningRateLogger(Callback):
r"""
Automatically logs learning rate for learning rate schedulers during training.
Args:
logging_interval: set to `epoch` or `step` to log `lr` of all optimizers
at the same interval, set to `None` to log at individual interval
according to the `interval` key of each scheduler. Defaults to ``None``.
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LearningRateLogger
>>> lr_logger = LearningRateLogger()
>>> lr_logger = LearningRateLogger(logging_interval='step')
>>> trainer = Trainer(callbacks=[lr_logger])
Logging names are automatically determined based on optimizer class name.
@@ -37,7 +44,13 @@ class LearningRateLogger(Callback):
'name': 'my_logging_name'}
return [optimizer], [lr_scheduler]
"""
def __init__(self):
def __init__(self, logging_interval: Optional[str] = None):
if logging_interval not in (None, 'step', 'epoch'):
raise MisconfigurationException(
'logging_interval should be `step` or `epoch` or `None`.'
)
self.logging_interval = logging_interval
self.lrs = None
self.lr_sch_names = []
@@ -48,7 +61,8 @@ class LearningRateLogger(Callback):
"""
if not trainer.logger:
raise MisconfigurationException(
'Cannot use LearningRateLogger callback with Trainer that has no logger.')
'Cannot use LearningRateLogger callback with Trainer that has no logger.'
)
if not trainer.lr_schedulers:
rank_zero_warn(
@@ -63,22 +77,29 @@ class LearningRateLogger(Callback):
# Initialize for storing values
self.lrs = {name: [] for name in names}
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
latest_stat = self._extract_lr(trainer, 'step')
if trainer.logger and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
def on_batch_start(self, trainer, pl_module):
if self.logging_interval != 'epoch':
interval = 'step' if self.logging_interval is None else 'any'
latest_stat = self._extract_lr(trainer, interval)
if trainer.logger is not None and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
def on_epoch_start(self, trainer, pl_module):
latest_stat = self._extract_lr(trainer, 'epoch')
if trainer.logger and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
if self.logging_interval != 'step':
interval = 'epoch' if self.logging_interval is None else 'any'
latest_stat = self._extract_lr(trainer, interval)
if trainer.logger is not None and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch)
def _extract_lr(self, trainer, interval):
""" Extracts learning rates for lr schedulers and saves information
into dict structure. """
latest_stat = {}
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
if scheduler['interval'] == interval:
if scheduler['interval'] == interval or interval == 'any':
param_groups = scheduler['scheduler'].optimizer.param_groups
if len(param_groups) != 1:
for i, pg in enumerate(param_groups):
@@ -88,6 +109,7 @@ class LearningRateLogger(Callback):
else:
self.lrs[name].append(param_groups[0]['lr'])
latest_stat[name] = param_groups[0]['lr']
return latest_stat
def _find_names(self, lr_schedulers):
@@ -109,6 +131,7 @@ class LearningRateLogger(Callback):
# Multiple param groups for the same scheduler
param_groups = sch.optimizer.param_groups
if len(param_groups) != 1:
for i, pg in enumerate(param_groups):
temp = f'{name}/pg{i + 1}'
@@ -117,4 +140,5 @@ class LearningRateLogger(Callback):
names.append(name)
self.lr_sch_names.append(name)
return names
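
For context, here is a minimal sketch of a `configure_optimizers()` that exercises the per-scheduler path above (i.e. `logging_interval=None`); the Adam/StepLR settings and the scheduler name are illustrative, only the `'interval'` and `'name'` keys come from the docstring in this file:

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=1e-3)  # illustrative settings
    lr_scheduler = {
        'scheduler': StepLR(optimizer, step_size=1, gamma=0.9),
        # _extract_lr() records this scheduler when the interval matches, so
        # with logging_interval=None it is logged once per epoch.
        'interval': 'epoch',
        # Picked up by _find_names(); without it the key falls back to
        # 'lr-<OptimizerName>' (e.g. 'lr-Adam'), as the test below asserts.
        'name': 'my_logging_name',
    }
    return [optimizer], [lr_scheduler]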

test_lr_logger.py

@@ -50,14 +50,15 @@ def test_lr_logger_no_lr(tmpdir):
assert result
def test_lr_logger_multi_lrs(tmpdir):
@pytest.mark.parametrize("logging_interval", ['step', 'epoch'])
def test_lr_logger_multi_lrs(tmpdir, logging_interval):
""" Test that learning rates are extracted and logged for multi lr schedulers. """
tutils.reset_seed()
model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__multiple_schedulers
lr_logger = LearningRateLogger()
lr_logger = LearningRateLogger(logging_interval=logging_interval)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
@@ -73,8 +74,14 @@ def test_lr_logger_multi_lrs(tmpdir):
'Number of learning rates logged does not match number of lr schedulers'
assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'
assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
'Length of logged learning rates exceeds the number of epochs'
if logging_interval == 'step':
expected_number_logged = trainer.global_step
if logging_interval == 'epoch':
expected_number_logged = trainer.max_epochs
assert all(len(lr) == expected_number_logged for lr in lr_logger.lrs.values()), \
'Length of logged learning rates do not match the expected number'
def test_lr_logger_param_groups(tmpdir):
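
The parametrized assertion above reduces to a simple count check; a worked sketch with assumed numbers (2 epochs of 10 training batches each, neither taken from the test):

# Illustrative run length: 2 epochs x 10 training batches per epoch.
max_epochs, batches_per_epoch = 2, 10

expected_logged = {
    'step': max_epochs * batches_per_epoch,  # on_batch_start fires 20 times, so global_step == 20
    'epoch': max_epochs,                     # on_epoch_start fires 2 times
}
assert expected_logged['step'] == 20
assert expected_logged['epoch'] == 2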