From 3a68493d0aa73ef8df88d5ec09e400c2285a501e Mon Sep 17 00:00:00 2001
From: Bas Krahmer
Date: Thu, 18 May 2023 19:01:52 +0200
Subject: [PATCH] Log `LearningRateMonitor` values to `Trainer.callback_metrics`
 for `EarlyStopping` (#17626)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 src/lightning/pytorch/CHANGELOG.md            |  3 ++
 src/lightning/pytorch/callbacks/lr_monitor.py |  5 +++
 .../callbacks/test_lr_monitor.py              | 39 ++++++++++++++++++-
 3 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index e78a5f2ee5..8dcec7db41 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added non-layer param count to the model summary ([#17005](https://github.com/Lightning-AI/lightning/pull/17005))
 
+- Updated `LearningRateMonitor` to log monitored values to `trainer.callback_metrics` ([#17626](https://github.com/Lightning-AI/lightning/pull/17626))
+
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py
index 65ccfbb5ff..d938db61d6 100644
--- a/src/lightning/pytorch/callbacks/lr_monitor.py
+++ b/src/lightning/pytorch/callbacks/lr_monitor.py
@@ -23,6 +23,7 @@ import itertools
 from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type
 
+import torch
 from torch.optim.optimizer import Optimizer
 
 import lightning.pytorch as pl
@@ -193,6 +194,10 @@ class LearningRateMonitor(Callback):
             current_stat = self._get_lr_momentum_stat(opt, names)
             latest_stat.update(current_stat)
 
+        trainer.callback_metrics.update(
+            {name: torch.tensor(value, device=trainer.strategy.root_device) for name, value in latest_stat.items()}
+        )
+
         return latest_stat
 
     def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]:
diff --git a/tests/tests_pytorch/callbacks/test_lr_monitor.py b/tests/tests_pytorch/callbacks/test_lr_monitor.py
index e0b930a675..7f89ea845b 100644
--- a/tests/tests_pytorch/callbacks/test_lr_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_lr_monitor.py
@@ -16,7 +16,7 @@ import torch
 from torch import optim
 
 from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks import LearningRateMonitor
+from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -626,3 +626,40 @@ def test_lr_monitor_multiple_param_groups_no_lr_scheduler(tmpdir):
     assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
     assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
     assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)
+
+
+def test_lr_monitor_update_callback_metrics(tmpdir):
+    """Test that the `LearningRateMonitor` callback updates trainer.callback_metrics."""
+
+    class TestModel(BoringModel):
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
+            return [optimizer], [lr_scheduler]
+
+    monitor_key = "lr-SGD"
+    stop_threshold = 0.02
+    expected_stop_epoch = 3
+
+    lr_monitor = LearningRateMonitor()
+    lr_es = EarlyStopping(
+        monitor=monitor_key, mode="min", stopping_threshold=stop_threshold, check_on_train_epoch_end=True
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[lr_monitor, lr_es],
+        max_epochs=5,
+        limit_val_batches=0,
+        limit_train_batches=2,
+        logger=CSVLogger(tmpdir),
+    )
+    model = TestModel()
+    trainer.fit(model)
+
+    assert monitor_key in trainer.callback_metrics
+    assert lr_monitor.lrs[monitor_key] == [0.1, 0.05, 0.025, 0.0125]
+    assert min(lr_monitor.lrs[monitor_key][:expected_stop_epoch]) > stop_threshold
+    assert len(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) == 1
+    assert min(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) < stop_threshold
+    assert trainer.current_epoch - 1 == expected_stop_epoch
+    assert lr_es.stopped_epoch == expected_stop_epoch
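
Usage note (not part of the patch): with this change, the values recorded by `LearningRateMonitor` (e.g. "lr-SGD" for a single SGD optimizer) also appear in `trainer.callback_metrics`, so other callbacks such as `EarlyStopping` can monitor them directly. Below is a minimal sketch of how a user might stop training once the learning rate decays below a threshold; the model, scheduler, and threshold are illustrative assumptions, not taken from the patch:

    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
    from lightning.pytorch.demos.boring_classes import BoringModel

    class DecayingLRModel(BoringModel):  # hypothetical example model
        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            # halve the learning rate every epoch (illustrative choice)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
            return [optimizer], [scheduler]

    trainer = Trainer(
        max_epochs=20,
        limit_val_batches=0,
        callbacks=[
            LearningRateMonitor(),
            # stop as soon as the monitored "lr-SGD" value falls below 1e-3;
            # check_on_train_epoch_end=True because there is no validation loop here
            EarlyStopping(monitor="lr-SGD", mode="min", stopping_threshold=1e-3, check_on_train_epoch_end=True),
        ],
    )
    trainer.fit(DecayingLRModel())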