[feat] Named Parameter Groups in `LearningRateMonitor` (#7987)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent 5647087f03
commit 906de2a7fa
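The feature in one picture: `LearningRateMonitor` now reads an optional `name` key from each optimizer parameter group and uses it as the suffix of the logged key, falling back to the positional `pg1`, `pg2`, ... suffixes otherwise. A minimal usage sketch, assuming a toy `LightningModule`; the `ToyModel`, layer sizes and dataloader below are illustrative and not part of this commit:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import LearningRateMonitor


    class ToyModel(pl.LightningModule):
        """Hypothetical module, used only to illustrate named parameter groups."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def train_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8)

        def configure_optimizers(self):
            # the 'name' entry of the parameter group is picked up by LearningRateMonitor
            optimizer = torch.optim.SGD(
                [{'params': self.layer.parameters(), 'name': 'linear'}], lr=0.1
            )
            return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)]


    # the learning rate is then logged as 'lr-SGD/linear'
    # (a single unnamed group would log plain 'lr-SGD')
    trainer = pl.Trainer(max_epochs=1, callbacks=[LearningRateMonitor(logging_interval='epoch')])
    trainer.fit(ToyModel())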
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Add support for named parameter groups in `LearningRateMonitor` ([#7987](https://github.com/PyTorchLightning/pytorch-lightning/pull/7987))
+
 - Add `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935))

@@ -20,7 +20,7 @@ Monitor and logs learning rate for lr schedulers during training.
 """
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Optional, Type
+from typing import Any, DefaultDict, Dict, List, Optional, Set, Type

 from torch.optim.optimizer import Optimizer

@@ -55,7 +55,9 @@ class LearningRateMonitor(Callback):
     In case of multiple optimizers of same type, they will be named ``Adam``,
     ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
     be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
-    ``name`` keyword in the construction of the learning rate schedulers
+    ``name`` keyword in the construction of the learning rate schedulers.
+    A ``name`` keyword can also be used for parameter groups in the
+    construction of the optimizer.

     Example::

@@ -67,6 +69,19 @@ class LearningRateMonitor(Callback):
             }
             return [optimizer], [lr_scheduler]

+    Example::
+
+        def configure_optimizer(self):
+            optimizer = torch.optim.SGD(
+                [{
+                    'params': [p for p in self.parameters()],
+                    'name': 'my_parameter_group_name'
+                }],
+                lr=0.1
+            )
+            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
+            return [optimizer], [lr_scheduler]
+
     """

     def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):

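The docstring above also mentions controlling the logged key via the scheduler rather than the parameter group. A short sketch of that variant: the dict-style scheduler config with a ``name`` key is the same mechanism the existing `test_lr_monitor_custom_name` test checks against `my_logging_name`, while the `Adam` hyperparameters and `LambdaLR` factor below are arbitrary:

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.95 ** epoch),
            # this 'name' replaces the default 'lr-Adam' key used by LearningRateMonitor
            'name': 'my_logging_name',
        }
        return [optimizer], [lr_scheduler]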
@@ -150,11 +165,11 @@ class LearningRateMonitor(Callback):
                 use_betas = 'betas' in opt.defaults

                 for i, pg in enumerate(param_groups):
-                    suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
-                    lr = self._extract_lr(pg, f'{name}{suffix}')
+                    name_and_suffix = self._add_suffix(name, param_groups, i)
+                    lr = self._extract_lr(pg, name_and_suffix)
                     latest_stat.update(lr)
                     momentum = self._extract_momentum(
-                        param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
+                        param_group=pg, name=name_and_suffix.replace(name, f'{name}-momentum'), use_betas=use_betas
                     )
                     latest_stat.update(momentum)

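The `.replace` above exists so the momentum key carries the same suffix as the learning-rate key, whether that suffix is positional (`/pg2`) or a group name. A tiny sketch with hardcoded values; `lr-SGD` and `linear` are taken from the tests further down rather than computed here:

    name = 'lr-SGD'                        # base name derived from the optimizer class
    name_and_suffix = 'lr-SGD/linear'      # what _add_suffix() returns for a group named 'linear'

    momentum_key = name_and_suffix.replace(name, f'{name}-momentum')
    assert momentum_key == 'lr-SGD-momentum/linear'   # logged when log_momentum=True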
@@ -192,6 +207,26 @@ class LearningRateMonitor(Callback):
         count = seen_optimizer_types[optimizer_cls]
         return name + f'-{count - 1}' if count > 1 else name

+    def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str:
+        if len(param_groups) > 1:
+            if not use_names:
+                return f'{name}/pg{param_group_index+1}'
+            else:
+                pg_name = param_groups[param_group_index].get('name', f'pg{param_group_index+1}')
+                return f'{name}/{pg_name}'
+        elif use_names:
+            pg_name = param_groups[param_group_index].get('name')
+            return f'{name}/{pg_name}' if pg_name else name
+        return name
+
+    def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
+        names = [pg.get('name', f'pg{i}') for i, pg in enumerate(param_groups, start=1)]
+        unique = set(names)
+        if len(names) == len(unique):
+            return set()
+        else:
+            return set(n for n in names if names.count(n) > 1)
+
     def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
         # Create unique names in the case we have multiple of the same learning
         # rate scheduler + multiple parameter groups

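What the two new helpers return, shown by calling the private methods directly; this is purely illustrative, and the `lr-Adam`, `encoder` and `decoder` names are made up:

    from pytorch_lightning.callbacks import LearningRateMonitor

    monitor = LearningRateMonitor()
    named = [{'params': [], 'name': 'encoder'}, {'params': [], 'name': 'decoder'}]
    unnamed = [{'params': []}, {'params': []}]

    # named groups win over the positional pg<i> fallback
    assert monitor._add_suffix('lr-Adam', named, 0) == 'lr-Adam/encoder'
    assert monitor._add_suffix('lr-Adam', unnamed, 1) == 'lr-Adam/pg2'
    # a single unnamed group keeps the bare optimizer name
    assert monitor._add_suffix('lr-Adam', unnamed[:1], 0) == 'lr-Adam'
    # use_names=False restores the purely positional behaviour
    assert monitor._add_suffix('lr-Adam', named, 0, use_names=False) == 'lr-Adam/pg1'

    # only names that actually collide are reported
    assert monitor._duplicate_param_group_names(named) == set()
    assert monitor._duplicate_param_group_names([{'name': 'linear'}, {'name': 'linear'}]) == {'linear'}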
@@ -212,15 +247,16 @@ class LearningRateMonitor(Callback):

             # Multiple param groups for the same scheduler
             param_groups = sch.optimizer.param_groups
+            duplicates = self._duplicate_param_group_names(param_groups)
+            if duplicates:
+                raise MisconfigurationException(
+                    'A single `Optimizer` cannot have multiple parameter groups with identical '
+                    f'`name` values. {name} has duplicated parameter group names {duplicates}'
+                )

             name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)

-            if len(param_groups) != 1:
-                for i in range(len(param_groups)):
-                    temp = f'{name}/pg{i + 1}'
-                    names.append(temp)
-            else:
-                names.append(name)
+            names.extend(self._add_suffix(name, param_groups, i) for i in range(len(param_groups)))

             if add_lr_sch_names:
                 self.lr_sch_names.append(name)

@@ -283,6 +283,77 @@ def test_lr_monitor_custom_name(tmpdir):
     assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ['my_logging_name']


+def test_lr_monitor_custom_pg_name(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD([{'params': [p for p in self.layer.parameters()], 'name': 'linear'}], lr=0.1)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    lr_monitor = LearningRateMonitor()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_val_batches=2,
+        limit_train_batches=2,
+        callbacks=[lr_monitor],
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    trainer.fit(TestModel())
+    assert lr_monitor.lr_sch_names == ['lr-SGD']
+    assert list(lr_monitor.lrs) == ['lr-SGD/linear']
+
+
+def test_lr_monitor_duplicate_custom_pg_names(tmpdir):
+    tutils.reset_seed()
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.linear_a = torch.nn.Linear(32, 16)
+            self.linear_b = torch.nn.Linear(16, 2)
+
+        def forward(self, x):
+            x = self.linear_a(x)
+            x = self.linear_b(x)
+            return x
+
+        def configure_optimizers(self):
+            param_groups = [
+                {
+                    'params': [p for p in self.linear_a.parameters()],
+                    'name': 'linear'
+                },
+                {
+                    'params': [p for p in self.linear_b.parameters()],
+                    'name': 'linear'
+                },
+            ]
+            optimizer = torch.optim.SGD(param_groups, lr=0.1)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    lr_monitor = LearningRateMonitor()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_val_batches=2,
+        limit_train_batches=2,
+        callbacks=[lr_monitor],
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+
+    with pytest.raises(
+        MisconfigurationException, match='A single `Optimizer` cannot have multiple parameter groups with identical'
+    ):
+        trainer.fit(TestModel())
+
+
 def test_multiple_optimizers_basefinetuning(tmpdir):

     class TestModel(BoringModel):