Update `LearningRateMonitor` docs and tests for `log_weight_decay` (#19805)

Gilles Peiffer 2024-05-21 19:31:54 +02:00 committed by GitHub
parent d76feef0d6
commit b1bb3f3173
2 changed files with 6 additions and 1 deletion


@@ -44,6 +44,8 @@ class LearningRateMonitor(Callback):
according to the ``interval`` key of each scheduler. Defaults to ``None``.
log_momentum: option to also log the momentum values of the optimizer, if the optimizer
has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
+ log_weight_decay: option to also log the weight decay values of the optimizer. Defaults to
+     ``False``.
Raises:
MisconfigurationException:
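
The options documented above map directly onto the callback's constructor. A minimal usage sketch, not part of this commit's diff; the BoringModel demo module and the short Trainer settings are assumptions used only to keep the example self-contained:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.demos.boring_classes import BoringModel

# Log the learning rate plus momentum/betas and weight decay for every optimizer.
lr_monitor = LearningRateMonitor(logging_interval="step", log_momentum=True, log_weight_decay=True)
trainer = Trainer(max_epochs=1, limit_train_batches=4, callbacks=[lr_monitor])
trainer.fit(BoringModel())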
@@ -58,7 +60,7 @@ class LearningRateMonitor(Callback):
Logging names are automatically determined based on optimizer class name.
In case of multiple optimizers of same type, they will be named ``Adam``,
- ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
+ ``Adam-1`` etc. If an optimizer has multiple parameter groups they will
be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
``name`` keyword in the construction of the learning rate schedulers.
A ``name`` keyword can also be used for parameter groups in the
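
A hedged sketch of the naming hook described in this docstring: passing a ``name`` key in the lr scheduler config returned from ``configure_optimizers`` controls the logged key. The model subclass and hyperparameters here are illustrative assumptions:

import torch
from lightning.pytorch.demos.boring_classes import BoringModel


class NamedSchedulerModel(BoringModel):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-2)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # The "name" key replaces the automatically derived "lr-Adam" key in the logs.
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "name": "my-lr"}}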


@@ -44,6 +44,9 @@ def test_lr_monitor_single_lr(tmp_path):
assert lr_monitor.lrs, "No learning rates logged"
assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
+ assert all(
+     v is None for v in lr_monitor.last_weight_decay_values.values()
+ ), "Weight decay should not be logged by default"
assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs)
assert list(lr_monitor.lrs) == ["lr-SGD"]
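
For symmetry, a hedged sketch of the complementary check, not part of this commit's diff: with ``log_weight_decay=True`` the recorded values should be populated. The test name, logger choice, and Trainer settings below are assumptions kept minimal:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import CSVLogger


def test_lr_monitor_logs_weight_decay(tmp_path):
    lr_monitor = LearningRateMonitor(log_weight_decay=True)
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=0,
        callbacks=[lr_monitor],
        logger=CSVLogger(tmp_path),
        enable_checkpointing=False,
        enable_progress_bar=False,
    )
    trainer.fit(BoringModel())
    # Every tracked optimizer should now report a concrete weight decay value.
    assert all(v is not None for v in lr_monitor.last_weight_decay_values.values())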