From b1bb3f31735e7df22097a8c15dd5323d75fda45d Mon Sep 17 00:00:00 2001
From: Gilles Peiffer
Date: Tue, 21 May 2024 19:31:54 +0200
Subject: [PATCH] Update `LearningRateMonitor` docs and tests for `log_weight_decay` (#19805)

---
 src/lightning/pytorch/callbacks/lr_monitor.py    | 4 +++-
 tests/tests_pytorch/callbacks/test_lr_monitor.py | 3 +++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py
index 357cfceefa..6a94c7ece7 100644
--- a/src/lightning/pytorch/callbacks/lr_monitor.py
+++ b/src/lightning/pytorch/callbacks/lr_monitor.py
@@ -44,6 +44,8 @@ class LearningRateMonitor(Callback):
             according to the ``interval`` key of each scheduler. Defaults to ``None``.
         log_momentum: option to also log the momentum values of the optimizer, if the optimizer
             has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
+        log_weight_decay: option to also log the weight decay values of the optimizer. Defaults to
+            ``False``.
 
     Raises:
         MisconfigurationException:
@@ -58,7 +60,7 @@ class LearningRateMonitor(Callback):
 
     Logging names are automatically determined based on optimizer class name.
     In case of multiple optimizers of same type, they will be named ``Adam``,
-    ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
+    ``Adam-1`` etc. If an optimizer has multiple parameter groups they will
     be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
     ``name`` keyword in the construction of the learning rate schedulers.
     A ``name`` keyword can also be used for parameter groups in the
diff --git a/tests/tests_pytorch/callbacks/test_lr_monitor.py b/tests/tests_pytorch/callbacks/test_lr_monitor.py
index ebe21e272a..4aedb4f23f 100644
--- a/tests/tests_pytorch/callbacks/test_lr_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_lr_monitor.py
@@ -44,6 +44,9 @@ def test_lr_monitor_single_lr(tmp_path):
 
     assert lr_monitor.lrs, "No learning rates logged"
     assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
+    assert all(
+        v is None for v in lr_monitor.last_weight_decay_values.values()
+    ), "Weight decay should not be logged by default"
     assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs)
     assert list(lr_monitor.lrs) == ["lr-SGD"]
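
For context, a minimal usage sketch of the option this patch documents. Only the
`LearningRateMonitor(log_weight_decay=True)` argument comes from the docstring change above;
the `TinyModel` module, the synthetic dataset, and the trainer settings are illustrative
assumptions made up for the example and are not part of the Lightning repository.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import LearningRateMonitor


    class TinyModel(pl.LightningModule):
        """Hypothetical module, only here so the example is runnable."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            # Use a non-zero weight_decay so the monitor has something to log.
            return torch.optim.SGD(self.parameters(), lr=0.1, weight_decay=1e-4)


    # log_weight_decay defaults to False (as the docstring change states);
    # enabling it logs the optimizer's weight decay alongside the learning rate.
    lr_monitor = LearningRateMonitor(logging_interval="step", log_weight_decay=True)

    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    trainer = pl.Trainer(max_epochs=1, callbacks=[lr_monitor], log_every_n_steps=1)
    trainer.fit(TinyModel(), DataLoader(dataset, batch_size=16))

The callback requires a logger, which the default `Trainer` already provides, so none is
configured explicitly in this sketch.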