Update lr_logger.py (#2847)
* Update lr_logger.py: when logging the learning rate, we should offer different logging intervals, including 'step' and 'epoch'
* Update lr_logger.py: add some type annotations and docstrings
* Update lr_logger.py: fix a bug where `on_train_batch_start()` is never triggered; use `on_batch_start()` instead. Add an `interval` argument so that learning rates can be recorded with respect to `global_step` or `current_epoch`.
* Update lr_logger.py: restore `_extract_lr()`
* suggestion
* Update lr_logger.py: modify `_extract_lr()` so it no longer needs the `interval` parameter
* Update test_lr_logger.py: SkafteNicki's suggestion
* log_interval now supports `None`, `step`, `epoch`
* change `log_interval` to `logging_interval`
* Update test_lr_logger.py
* Update lr_logger.py
* move the type check into `on_train_start()`
* cleanup
* docstring typos
* minor changes from suggestions

Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent 0cfa05b703
commit 6c18fd9a24
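In short, the callback gains a `logging_interval` argument; its accepted values and their effect, condensed from the updated docstring and code below (a usage sketch, not part of the diff itself):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateLogger

# None (default): follow each scheduler's own 'interval' key, as before
lr_logger = LearningRateLogger()
# 'step': log every optimizer's lr on every batch, against trainer.global_step
lr_logger = LearningRateLogger(logging_interval='step')
# 'epoch': log every optimizer's lr once per epoch, against trainer.current_epoch
lr_logger = LearningRateLogger(logging_interval='epoch')

trainer = Trainer(callbacks=[lr_logger])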
lr_logger.py

@@ -7,6 +7,8 @@ Log learning rate for lr schedulers during training
 
 """
 
+from typing import Optional
+
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -16,11 +18,16 @@ class LearningRateLogger(Callback):
     r"""
     Automatically logs learning rate for learning rate schedulers during training.
 
+    Args:
+        logging_interval: set to `epoch` or `step` to log `lr` of all optimizers
+            at the same interval, set to `None` to log at individual interval
+            according to the `interval` key of each scheduler. Defaults to ``None``.
+
     Example::
 
         >>> from pytorch_lightning import Trainer
         >>> from pytorch_lightning.callbacks import LearningRateLogger
-        >>> lr_logger = LearningRateLogger()
+        >>> lr_logger = LearningRateLogger(logging_interval='step')
        >>> trainer = Trainer(callbacks=[lr_logger])
 
     Logging names are automatically determined based on optimizer class name.
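The `logging_interval=None` case defers to the `interval` key of each scheduler dict returned from `configure_optimizers`, and the optional `name` key overrides the auto-generated logging name. A hypothetical module illustrating that dict form (placeholder model, not taken from this diff):

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # placeholder module for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.95 ** epoch),
            'interval': 'step',         # honoured by LearningRateLogger(logging_interval=None)
            'name': 'my_logging_name',  # overrides the auto-generated 'lr-Adam'
        }
        return [optimizer], [lr_scheduler]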
@@ -37,7 +44,13 @@ class LearningRateLogger(Callback):
                             'name': 'my_logging_name'}
             return [optimizer], [lr_scheduler]
     """
-    def __init__(self):
+    def __init__(self, logging_interval: Optional[str] = None):
+        if logging_interval not in (None, 'step', 'epoch'):
+            raise MisconfigurationException(
+                'logging_interval should be `step` or `epoch` or `None`.'
+            )
+
+        self.logging_interval = logging_interval
         self.lrs = None
         self.lr_sch_names = []
 
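Because the check sits in `__init__`, an invalid value now fails at construction time rather than partway through training; for example (assuming the code in this diff):

from pytorch_lightning.callbacks import LearningRateLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    LearningRateLogger(logging_interval='batch')  # not one of None / 'step' / 'epoch'
except MisconfigurationException as err:
    print(err)  # logging_interval should be `step` or `epoch` or `None`.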
@@ -48,7 +61,8 @@ class LearningRateLogger(Callback):
         """
         if not trainer.logger:
             raise MisconfigurationException(
-                'Cannot use LearningRateLogger callback with Trainer that has no logger.')
+                'Cannot use LearningRateLogger callback with Trainer that has no logger.'
+            )
 
         if not trainer.lr_schedulers:
             rank_zero_warn(
@@ -63,22 +77,29 @@ class LearningRateLogger(Callback):
         # Initialize for storing values
         self.lrs = {name: [] for name in names}
 
-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
-        latest_stat = self._extract_lr(trainer, 'step')
-        if trainer.logger and latest_stat:
-            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+    def on_batch_start(self, trainer, pl_module):
+        if self.logging_interval != 'epoch':
+            interval = 'step' if self.logging_interval is None else 'any'
+            latest_stat = self._extract_lr(trainer, interval)
+
+            if trainer.logger is not None and latest_stat:
+                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def on_epoch_start(self, trainer, pl_module):
-        latest_stat = self._extract_lr(trainer, 'epoch')
-        if trainer.logger and latest_stat:
-            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+        if self.logging_interval != 'step':
+            interval = 'epoch' if self.logging_interval is None else 'any'
+            latest_stat = self._extract_lr(trainer, interval)
+
+            if trainer.logger is not None and latest_stat:
+                trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch)
 
     def _extract_lr(self, trainer, interval):
         """ Extracts learning rates for lr schedulers and saves information
         into dict structure. """
         latest_stat = {}
+
         for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
-            if scheduler['interval'] == interval:
+            if scheduler['interval'] == interval or interval == 'any':
                 param_groups = scheduler['scheduler'].optimizer.param_groups
                 if len(param_groups) != 1:
                     for i, pg in enumerate(param_groups):
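Both hooks share the same dispatch rule for which schedulers to read; a standalone sketch of that rule (the helper name is mine, not part of the diff):

def interval_for_hook(logging_interval, hook):
    """Return the scheduler 'interval' to match from a hook, or None to log nothing.

    `hook` is 'step' for on_batch_start and 'epoch' for on_epoch_start.
    """
    if logging_interval is None:
        return hook   # fall back to each scheduler's own 'interval' key
    if logging_interval != hook:
        return None   # e.g. logging_interval='epoch' silences on_batch_start
    return 'any'      # log every scheduler from this hook, whatever its key


assert interval_for_hook(None, 'step') == 'step'
assert interval_for_hook('epoch', 'step') is None
assert interval_for_hook('step', 'step') == 'any'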
@@ -88,6 +109,7 @@ class LearningRateLogger(Callback):
                 else:
                     self.lrs[name].append(param_groups[0]['lr'])
                     latest_stat[name] = param_groups[0]['lr']
+
         return latest_stat
 
     def _find_names(self, lr_schedulers):
@@ -109,6 +131,7 @@ class LearningRateLogger(Callback):
 
             # Multiple param groups for the same schduler
             param_groups = sch.optimizer.param_groups
+
             if len(param_groups) != 1:
                 for i, pg in enumerate(param_groups):
                     temp = f'{name}/pg{i + 1}'
@@ -117,4 +140,5 @@ class LearningRateLogger(Callback):
                 names.append(name)
+
             self.lr_sch_names.append(name)
 
         return names
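For reference, the naming scheme that `_find_names()` and `_extract_lr()` produce, re-implemented compactly for illustration (not the code in the diff):

def logging_names(optimizer_class_names, n_param_groups_each):
    """Keys under which learning rates are logged: 'lr-<Optimizer>', de-duplicated
    with -1, -2, ..., and split into /pg1, /pg2, ... for multiple parameter groups."""
    names, seen = [], {}
    for cls_name, n_groups in zip(optimizer_class_names, n_param_groups_each):
        base = f'lr-{cls_name}'
        count = seen.get(base, 0)
        seen[base] = count + 1
        name = base if count == 0 else f'{base}-{count}'
        if n_groups == 1:
            names.append(name)
        else:
            names.extend(f'{name}/pg{i + 1}' for i in range(n_groups))
    return names


# Matches the keys asserted in test_lr_logger_multi_lrs below:
assert logging_names(['Adam', 'Adam'], [1, 1]) == ['lr-Adam', 'lr-Adam-1']
assert logging_names(['SGD'], [2]) == ['lr-SGD/pg1', 'lr-SGD/pg2']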
test_lr_logger.py

@@ -50,14 +50,15 @@ def test_lr_logger_no_lr(tmpdir):
     assert result
 
 
-def test_lr_logger_multi_lrs(tmpdir):
+@pytest.mark.parametrize("logging_interval", ['step', 'epoch'])
+def test_lr_logger_multi_lrs(tmpdir, logging_interval):
     """ Test that learning rates are extracted and logged for multi lr schedulers. """
     tutils.reset_seed()
 
     model = EvalModelTemplate()
     model.configure_optimizers = model.configure_optimizers__multiple_schedulers
 
-    lr_logger = LearningRateLogger()
+    lr_logger = LearningRateLogger(logging_interval=logging_interval)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=2,
@@ -73,8 +74,14 @@ def test_lr_logger_multi_lrs(tmpdir):
         'Number of learning rates logged does not match number of lr schedulers'
     assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
         'Names of learning rates not set correctly'
-    assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
-        'Length of logged learning rates exceeds the number of epochs'
+
+    if logging_interval == 'step':
+        expected_number_logged = trainer.global_step
+    if logging_interval == 'epoch':
+        expected_number_logged = trainer.max_epochs
+
+    assert all(len(lr) == expected_number_logged for lr in lr_logger.lrs.values()), \
+        'Length of logged learning rates do not match the expected number'
 
 
 def test_lr_logger_param_groups(tmpdir):
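The parametrized assertion reduces to the following expectation (illustrative numbers, not taken from an actual test run):

def expected_number_logged(logging_interval, global_step, max_epochs):
    # one value per on_batch_start call for 'step', one per epoch for 'epoch'
    return global_step if logging_interval == 'step' else max_epochs


assert expected_number_logged('step', global_step=20, max_epochs=2) == 20
assert expected_number_logged('epoch', global_step=20, max_epochs=2) == 2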