diff --git a/CHANGELOG.md b/CHANGELOG.md
index 75da2d4171..6946d11498 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Add support for named parameter groups in `LearningRateMonitor` ([#7987](https://github.com/PyTorchLightning/pytorch-lightning/pull/7987))
+
+
 - Add `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935))
 
 
diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py
index 5a8e8be513..c17e2289d2 100644
--- a/pytorch_lightning/callbacks/lr_monitor.py
+++ b/pytorch_lightning/callbacks/lr_monitor.py
@@ -20,7 +20,7 @@ Monitor and logs learning rate for lr schedulers during training.
 
 """
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Optional, Type
+from typing import Any, DefaultDict, Dict, List, Optional, Set, Type
 
 from torch.optim.optimizer import Optimizer
 
@@ -55,7 +55,9 @@ class LearningRateMonitor(Callback):
     In case of multiple optimizers of same type, they will be named
     ``Adam``, ``Adam-1`` etc. If a optimizer has multiple parameter groups
     they will be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
-    ``name`` keyword in the construction of the learning rate schedulers
+    ``name`` keyword in the construction of the learning rate schedulers.
+    A ``name`` keyword can also be used for parameter groups in the
+    construction of the optimizer.
 
     Example::
 
@@ -67,6 +69,19 @@ class LearningRateMonitor(Callback):
             }
             return [optimizer], [lr_scheduler]
 
+    Example::
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(
+                [{
+                    'params': [p for p in self.parameters()],
+                    'name': 'my_parameter_group_name'
+                }],
+                lr=0.1
+            )
+            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
+            return [optimizer], [lr_scheduler]
+
     """
 
     def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
@@ -150,11 +165,11 @@ class LearningRateMonitor(Callback):
                 use_betas = 'betas' in opt.defaults
 
                 for i, pg in enumerate(param_groups):
-                    suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
-                    lr = self._extract_lr(pg, f'{name}{suffix}')
+                    name_and_suffix = self._add_suffix(name, param_groups, i)
+                    lr = self._extract_lr(pg, name_and_suffix)
                     latest_stat.update(lr)
                     momentum = self._extract_momentum(
-                        param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
+                        param_group=pg, name=name_and_suffix.replace(name, f'{name}-momentum'), use_betas=use_betas
                     )
                     latest_stat.update(momentum)
 
@@ -192,6 +207,26 @@ class LearningRateMonitor(Callback):
         count = seen_optimizer_types[optimizer_cls]
         return name + f'-{count - 1}' if count > 1 else name
 
+    def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str:
+        if len(param_groups) > 1:
+            if not use_names:
+                return f'{name}/pg{param_group_index+1}'
+            else:
+                pg_name = param_groups[param_group_index].get('name', f'pg{param_group_index+1}')
+                return f'{name}/{pg_name}'
+        elif use_names:
+            pg_name = param_groups[param_group_index].get('name')
+            return f'{name}/{pg_name}' if pg_name else name
+        return name
+
+    def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
+        names = [pg.get('name', f'pg{i}') for i, pg in enumerate(param_groups, start=1)]
+        unique = set(names)
+        if len(names) == len(unique):
+            return set()
+        else:
+            return set(n for n in names if names.count(n) > 1)
+
     def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
         # Create unique names in the case we have multiple of the same learning
         # rate scheduler + multiple parameter groups
@@ -212,15 +247,16 @@ class LearningRateMonitor(Callback):
 
             # Multiple param groups for the same scheduler
            param_groups = sch.optimizer.param_groups
+            duplicates = self._duplicate_param_group_names(param_groups)
+            if duplicates:
+                raise MisconfigurationException(
+                    'A single `Optimizer` cannot have multiple parameter groups with identical '
+                    f'`name` values. {name} has duplicated parameter group names {duplicates}'
+                )
 
             name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
 
-            if len(param_groups) != 1:
-                for i in range(len(param_groups)):
-                    temp = f'{name}/pg{i + 1}'
-                    names.append(temp)
-            else:
-                names.append(name)
+            names.extend(self._add_suffix(name, param_groups, i) for i in range(len(param_groups)))
 
             if add_lr_sch_names:
                 self.lr_sch_names.append(name)
 
diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py
index 808165d61b..9b5ffc9728 100644
--- a/tests/callbacks/test_lr_monitor.py
+++ b/tests/callbacks/test_lr_monitor.py
@@ -283,6 +283,77 @@ def test_lr_monitor_custom_name(tmpdir):
     assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ['my_logging_name']
 
 
+def test_lr_monitor_custom_pg_name(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD([{'params': [p for p in self.layer.parameters()], 'name': 'linear'}], lr=0.1)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    lr_monitor = LearningRateMonitor()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_val_batches=2,
+        limit_train_batches=2,
+        callbacks=[lr_monitor],
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    trainer.fit(TestModel())
+    assert lr_monitor.lr_sch_names == ['lr-SGD']
+    assert list(lr_monitor.lrs) == ['lr-SGD/linear']
+
+
+def test_lr_monitor_duplicate_custom_pg_names(tmpdir):
+    tutils.reset_seed()
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.linear_a = torch.nn.Linear(32, 16)
+            self.linear_b = torch.nn.Linear(16, 2)
+
+        def forward(self, x):
+            x = self.linear_a(x)
+            x = self.linear_b(x)
+            return x
+
+        def configure_optimizers(self):
+            param_groups = [
+                {
+                    'params': [p for p in self.linear_a.parameters()],
+                    'name': 'linear'
+                },
+                {
+                    'params': [p for p in self.linear_b.parameters()],
+                    'name': 'linear'
+                },
+            ]
+            optimizer = torch.optim.SGD(param_groups, lr=0.1)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    lr_monitor = LearningRateMonitor()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_val_batches=2,
+        limit_train_batches=2,
+        callbacks=[lr_monitor],
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+
+    with pytest.raises(
+        MisconfigurationException, match='A single `Optimizer` cannot have multiple parameter groups with identical'
+    ):
+        trainer.fit(TestModel())
+
+
 def test_multiple_optimizers_basefinetuning(tmpdir):
 
     class TestModel(BoringModel):
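Below is a minimal, hypothetical usage sketch of the feature this patch adds; it is not part of the diff above. The model, the group names 'backbone'/'head', and the hyperparameters are illustrative. With the naming scheme introduced here, each named parameter group should surface in the logged metrics as '<optimizer name>/<group name>', mirroring the 'lr-SGD/linear' key asserted in test_lr_monitor_custom_pg_name.

# Usage sketch (illustrative; assumes a Lightning install that includes this patch).
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor


class NamedGroupsModel(pl.LightningModule):
    """Toy model whose optimizer carries two named parameter groups."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 16)
        self.head = torch.nn.Linear(16, 2)

    def training_step(self, batch, batch_idx):
        (x, ) = batch
        return self.head(self.backbone(x)).sum()

    def configure_optimizers(self):
        # Each parameter group gets a 'name' key, which the patched
        # LearningRateMonitor uses instead of the generic 'pg1'/'pg2' suffixes.
        optimizer = torch.optim.SGD(
            [
                {'params': self.backbone.parameters(), 'name': 'backbone'},
                {'params': self.head.parameters(), 'name': 'head'},
            ],
            lr=0.1,
        )
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


if __name__ == '__main__':
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, callbacks=[lr_monitor])
    trainer.fit(NamedGroupsModel(), DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4))
    # With this patch the logged keys follow '<optimizer>/<group name>':
    print(list(lr_monitor.lrs))  # expected: ['lr-SGD/backbone', 'lr-SGD/head']

Giving the two groups the same 'name' would instead raise the MisconfigurationException exercised by test_lr_monitor_duplicate_custom_pg_names.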