Log `LearningRateMonitor` values to `Trainer.callback_metrics` for `EarlyStopping` (#17626)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Bas Krahmer 2023-05-18 19:01:52 +02:00 committed by GitHub
parent 2ce975882d
commit 3a68493d0a
3 changed files with 46 additions and 1 deletion


@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added non-layer param count to the model summary ([#17005](https://github.com/Lightning-AI/lightning/pull/17005))

- Updated `LearningRateMonitor` to log monitored values to `trainer.callback_metrics` ([#17626](https://github.com/Lightning-AI/lightning/pull/17626))

### Changed

- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
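In practice, this changelog entry means an `EarlyStopping` callback can now monitor the learning-rate keys that `LearningRateMonitor` publishes. A minimal usage sketch, assuming a model that configures an SGD optimizer with a StepLR scheduler (as in the new test further down); "lr-SGD" is the key the monitor logs for a single-param-group SGD optimizer:

# Hedged sketch, not part of the diff: stop training once the scheduled LR drops below a threshold.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor

lr_monitor = LearningRateMonitor()
lr_es = EarlyStopping(monitor="lr-SGD", mode="min", stopping_threshold=0.02, check_on_train_epoch_end=True)
trainer = Trainer(max_epochs=5, callbacks=[lr_monitor, lr_es])
# trainer.fit(model)  # `model` is assumed; "lr-SGD" now appears in trainer.callback_metrics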


@@ -23,6 +23,7 @@ import itertools
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type
import torch
from torch.optim.optimizer import Optimizer
import lightning.pytorch as pl
@@ -193,6 +194,10 @@ class LearningRateMonitor(Callback):
            current_stat = self._get_lr_momentum_stat(opt, names)
            latest_stat.update(current_stat)

        trainer.callback_metrics.update(
            {name: torch.tensor(value, device=trainer.strategy.root_device) for name, value in latest_stat.items()}
        )

        return latest_stat

    def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]:
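For readers skimming the hunk above: each monitored value is wrapped in a tensor on `trainer.strategy.root_device` before being written to `trainer.callback_metrics`, the dict from which `EarlyStopping` looks up its `monitor` key. A self-contained sketch of the same pattern (CPU tensors and made-up values, for illustration only):

import torch

latest_stat = {"lr-SGD": 0.05}  # hypothetical output of _get_lr_momentum_stat
callback_metrics = {}
callback_metrics.update({name: torch.tensor(value) for name, value in latest_stat.items()})
# callback_metrics["lr-SGD"] is now a tensor that EarlyStopping can compare
# against its stopping_threshold.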


@@ -16,7 +16,7 @@ import torch
from torch import optim
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -626,3 +626,40 @@ def test_lr_monitor_multiple_param_groups_no_lr_scheduler(tmpdir):
    assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
    assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
    assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)


def test_lr_monitor_update_callback_metrics(tmpdir):
    """Test that the `LearningRateMonitor` callback updates trainer.callback_metrics."""

    class TestModel(BoringModel):
        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
            return [optimizer], [lr_scheduler]

    monitor_key = "lr-SGD"
    stop_threshold = 0.02
    expected_stop_epoch = 3

    lr_monitor = LearningRateMonitor()
    lr_es = EarlyStopping(
        monitor=monitor_key, mode="min", stopping_threshold=stop_threshold, check_on_train_epoch_end=True
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[lr_monitor, lr_es],
        max_epochs=5,
        limit_val_batches=0,
        limit_train_batches=2,
        logger=CSVLogger(tmpdir),
    )
    model = TestModel()
    trainer.fit(model)

    assert monitor_key in trainer.callback_metrics
    assert lr_monitor.lrs[monitor_key] == [0.1, 0.05, 0.025, 0.0125]
    assert min(lr_monitor.lrs[monitor_key][:expected_stop_epoch]) > stop_threshold
    assert len(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) == 1
    assert min(lr_monitor.lrs[monitor_key][expected_stop_epoch:]) < stop_threshold
    assert trainer.current_epoch - 1 == expected_stop_epoch
    assert lr_es.stopped_epoch == expected_stop_epoch