diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py
index a401cc8451..a4f2ddc757 100755
--- a/pytorch_lightning/callbacks/lr_logger.py
+++ b/pytorch_lightning/callbacks/lr_logger.py
@@ -7,6 +7,8 @@ Log learning rate for lr schedulers during training
 
 """
 
+from typing import Optional
+
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -16,11 +18,16 @@ class LearningRateLogger(Callback):
     r"""
     Automatically logs learning rate for learning rate schedulers during training.
 
+    Args:
+        logging_interval: set to `epoch` or `step` to log `lr` of all optimizers
+            at the same interval, set to `None` to log at individual interval
+            according to the `interval` key of each scheduler. Defaults to ``None``.
+
     Example::
 
         >>> from pytorch_lightning import Trainer
         >>> from pytorch_lightning.callbacks import LearningRateLogger
-        >>> lr_logger = LearningRateLogger()
+        >>> lr_logger = LearningRateLogger(logging_interval='step')
         >>> trainer = Trainer(callbacks=[lr_logger])
 
     Logging names are automatically determined based on optimizer class name.
@@ -37,7 +44,13 @@ class LearningRateLogger(Callback):
                             'name': 'my_logging_name'}
             return [optimizer], [lr_scheduler]
     """
-    def __init__(self):
+    def __init__(self, logging_interval: Optional[str] = None):
+        if logging_interval not in (None, 'step', 'epoch'):
+            raise MisconfigurationException(
+                'logging_interval should be `step` or `epoch` or `None`.'
+            )
+
+        self.logging_interval = logging_interval
         self.lrs = None
         self.lr_sch_names = []
 
@@ -48,7 +61,8 @@ class LearningRateLogger(Callback):
         """
         if not trainer.logger:
             raise MisconfigurationException(
-                'Cannot use LearningRateLogger callback with Trainer that has no logger.')
+                'Cannot use LearningRateLogger callback with Trainer that has no logger.'
+            )
 
         if not trainer.lr_schedulers:
             rank_zero_warn(
@@ -63,22 +77,29 @@ class LearningRateLogger(Callback):
         # Initialize for storing values
         self.lrs = {name: [] for name in names}
 
-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
-        latest_stat = self._extract_lr(trainer, 'step')
-        if trainer.logger and latest_stat:
-            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+    def on_batch_start(self, trainer, pl_module):
+        if self.logging_interval != 'epoch':
+            interval = 'step' if self.logging_interval is None else 'any'
+            latest_stat = self._extract_lr(trainer, interval)
+
+            if trainer.logger is not None and latest_stat:
+                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def on_epoch_start(self, trainer, pl_module):
-        latest_stat = self._extract_lr(trainer, 'epoch')
-        if trainer.logger and latest_stat:
-            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+        if self.logging_interval != 'step':
+            interval = 'epoch' if self.logging_interval is None else 'any'
+            latest_stat = self._extract_lr(trainer, interval)
+
+            if trainer.logger is not None and latest_stat:
+                trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch)
 
     def _extract_lr(self, trainer, interval):
         """ Extracts learning rates for lr schedulers and saves information
             into dict structure. """
         latest_stat = {}
+
         for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
-            if scheduler['interval'] == interval:
+            if scheduler['interval'] == interval or interval == 'any':
                 param_groups = scheduler['scheduler'].optimizer.param_groups
                 if len(param_groups) != 1:
                     for i, pg in enumerate(param_groups):
@@ -88,6 +109,7 @@ class LearningRateLogger(Callback):
                 else:
                     self.lrs[name].append(param_groups[0]['lr'])
                     latest_stat[name] = param_groups[0]['lr']
+
         return latest_stat
 
     def _find_names(self, lr_schedulers):
@@ -109,6 +131,7 @@ class LearningRateLogger(Callback):
 
             # Multiple param groups for the same schduler
             param_groups = sch.optimizer.param_groups
+
             if len(param_groups) != 1:
                 for i, pg in enumerate(param_groups):
                     temp = f'{name}/pg{i + 1}'
@@ -117,4 +140,5 @@ class LearningRateLogger(Callback):
                 names.append(name)
 
             self.lr_sch_names.append(name)
+
         return names
diff --git a/tests/callbacks/test_lr_logger.py b/tests/callbacks/test_lr_logger.py
index ff12f98fbf..264329fa3f 100644
--- a/tests/callbacks/test_lr_logger.py
+++ b/tests/callbacks/test_lr_logger.py
@@ -50,14 +50,15 @@ def test_lr_logger_no_lr(tmpdir):
     assert result
 
 
-def test_lr_logger_multi_lrs(tmpdir):
+@pytest.mark.parametrize("logging_interval", ['step', 'epoch'])
+def test_lr_logger_multi_lrs(tmpdir, logging_interval):
     """ Test that learning rates are extracted and logged for multi lr schedulers. """
     tutils.reset_seed()
 
     model = EvalModelTemplate()
     model.configure_optimizers = model.configure_optimizers__multiple_schedulers
 
-    lr_logger = LearningRateLogger()
+    lr_logger = LearningRateLogger(logging_interval=logging_interval)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=2,
@@ -73,8 +74,14 @@ def test_lr_logger_multi_lrs(tmpdir):
         'Number of learning rates logged does not match number of lr schedulers'
     assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
         'Names of learning rates not set correctly'
-    assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
-        'Length of logged learning rates exceeds the number of epochs'
+
+    if logging_interval == 'step':
+        expected_number_logged = trainer.global_step
+    if logging_interval == 'epoch':
+        expected_number_logged = trainer.max_epochs
+
+    assert all(len(lr) == expected_number_logged for lr in lr_logger.lrs.values()), \
+        'Length of logged learning rates do not match the expected number'
 
 
 def test_lr_logger_param_groups(tmpdir):