Fix lr key name in case of param groups (#1719)

* Fix lr key name in case of param groups

* Add tests

* Update test and add configure_optimizers__param_groups

* Update CHANGELOG
Rohit Gupta 2020-05-11 02:35:34 +05:30 committed by GitHub
parent 7f64ad7a33
commit d962ab5d89
5 changed files with 38 additions and 2 deletions

@@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748))
+- Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719))
 
 ## [0.7.5] - 2020-04-27
 
 ### Changed

@@ -80,7 +80,7 @@ class LearningRateLogger(Callback):
                 param_groups = scheduler['scheduler'].optimizer.param_groups
                 if len(param_groups) != 1:
                     for i, pg in enumerate(param_groups):
-                        lr, key = pg['lr'], f'{name}/{i + 1}'
+                        lr, key = pg['lr'], f'{name}/pg{i + 1}'
                         self.lrs[key].append(lr)
                         latest_stat[key] = lr
                 else:
@@ -109,7 +109,7 @@ class LearningRateLogger(Callback):
             param_groups = sch.optimizer.param_groups
             if len(param_groups) != 1:
                 for i, pg in enumerate(param_groups):
-                    temp = name + '/pg' + str(i + 1)
+                    temp = f'{name}/pg{i + 1}'
                     names.append(temp)
             else:
                 names.append(name)
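For orientation, here is a minimal standalone sketch of the naming rule introduced above (not the library code; the model, optimizer name, and learning rates are illustrative). With more than one param group, each group's lr is logged under a `<name>/pg<index>` key such as `lr-Adam/pg1`; a single-group optimizer keeps the plain name.

```python
# Illustrative sketch of the key naming after this fix (not the library code).
from collections import defaultdict

from torch import nn, optim

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
optimizer = optim.Adam([
    {'params': model[0].parameters(), 'lr': 1e-3},
    {'params': model[1].parameters(), 'lr': 1e-2},
])

name = 'lr-Adam'          # per-optimizer name, as used in the tests below
lrs = defaultdict(list)   # mirrors the callback's self.lrs mapping

param_groups = optimizer.param_groups
if len(param_groups) != 1:
    for i, pg in enumerate(param_groups):
        lrs[f'{name}/pg{i + 1}'].append(pg['lr'])   # 'lr-Adam/pg1', 'lr-Adam/pg2'
else:
    lrs[name].append(param_groups[0]['lr'])

print(dict(lrs))  # {'lr-Adam/pg1': [0.001], 'lr-Adam/pg2': [0.01]}
```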

tests/base/mixins.py (new file, empty)

@@ -59,3 +59,13 @@ class ConfigureOptimizersPool(ABC):
         optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
         lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
         return [optimizer], [lr_scheduler]
+
+    def configure_optimizers__param_groups(self):
+        param_groups = [
+            {'params': list(self.parameters())[:2], 'lr': self.hparams.learning_rate * 0.1},
+            {'params': list(self.parameters())[2:], 'lr': self.hparams.learning_rate}
+        ]
+
+        optimizer = optim.Adam(param_groups)
+        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
+        return [optimizer], [lr_scheduler]
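As a hedged aside on why two distinct series end up being logged (values below are illustrative, not taken from the test): a PyTorch scheduler such as `StepLR` rescales the lr of every param group, so each group keeps its own learning rate throughout training.

```python
# Illustrative only: StepLR rescales the lr of every param group on each step.
from torch import nn, optim

params = list(nn.Linear(4, 2).parameters())   # [weight, bias]
optimizer = optim.Adam([
    {'params': params[:1], 'lr': 0.002},      # analogous to learning_rate * 0.1
    {'params': params[1:], 'lr': 0.02},
])
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)

optimizer.step()   # call the optimizer first (PyTorch 1.1+ ordering)
scheduler.step()
print([pg['lr'] for pg in optimizer.param_groups])   # each lr scaled by gamma: ~[0.0002, 0.002]
```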

@@ -331,3 +331,27 @@ def test_lr_logger_multi_lrs(tmpdir):
         'Number of learning rates logged does not match number of lr schedulers'
     assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
         'Names of learning rates not set correctly'
+
+
+def test_lr_logger_param_groups(tmpdir):
+    """ Test that learning rates are extracted and logged for each param group """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__param_groups
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    results = trainer.fit(model)
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of param groups'
+    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
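Finally, a hedged usage sketch for end users (the Trainer arguments are illustrative and the callback import path is an assumption, not shown in this diff): attach `LearningRateLogger` as a callback, and when `configure_optimizers` returns param groups the logged keys follow the `lr-<Optimizer>/pg<index>` pattern checked by the test above.

```python
# Hedged usage sketch: log per-param-group learning rates during training.
# Import path is assumed; arguments other than `callbacks` are illustrative.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateLogger

lr_logger = LearningRateLogger()
trainer = Trainer(max_epochs=5, callbacks=[lr_logger])
# trainer.fit(model)  # model.configure_optimizers() returns the param-group
#                     # optimizer and scheduler, as in the helper above
# lr_logger.lrs -> {'lr-Adam/pg1': [...], 'lr-Adam/pg2': [...]}
```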