Fix lr key name in case of param groups (#1719)
* Fix lr key name in case of param groups * Add tests * Update test and added configure_optimizers__param_groups * Update CHANGELOG
This commit is contained in:
parent
7f64ad7a33
commit
d962ab5d89
|
@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748))
|
||||
|
||||
- Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719))
|
||||
|
||||
## [0.7.5] - 2020-04-27
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -80,7 +80,7 @@ class LearningRateLogger(Callback):
|
|||
param_groups = scheduler['scheduler'].optimizer.param_groups
|
||||
if len(param_groups) != 1:
|
||||
for i, pg in enumerate(param_groups):
|
||||
lr, key = pg['lr'], f'{name}/{i + 1}'
|
||||
lr, key = pg['lr'], f'{name}/pg{i + 1}'
|
||||
self.lrs[key].append(lr)
|
||||
latest_stat[key] = lr
|
||||
else:
|
||||
|
@ -109,7 +109,7 @@ class LearningRateLogger(Callback):
|
|||
param_groups = sch.optimizer.param_groups
|
||||
if len(param_groups) != 1:
|
||||
for i, pg in enumerate(param_groups):
|
||||
temp = name + '/pg' + str(i + 1)
|
||||
temp = f'{name}/pg{i + 1}'
|
||||
names.append(temp)
|
||||
else:
|
||||
names.append(name)
|
||||
|
|
|
@ -59,3 +59,13 @@ class ConfigureOptimizersPool(ABC):
|
|||
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
|
||||
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
def configure_optimizers__param_groups(self):
|
||||
param_groups = [
|
||||
{'params': list(self.parameters())[:2], 'lr': self.hparams.learning_rate * 0.1},
|
||||
{'params': list(self.parameters())[2:], 'lr': self.hparams.learning_rate}
|
||||
]
|
||||
|
||||
optimizer = optim.Adam(param_groups)
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
|
|
@ -331,3 +331,27 @@ def test_lr_logger_multi_lrs(tmpdir):
|
|||
'Number of learning rates logged does not match number of lr schedulers'
|
||||
assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
|
||||
'Names of learning rates not set correctly'
|
||||
|
||||
|
||||
def test_lr_logger_param_groups(tmpdir):
|
||||
""" Test that learning rates are extracted and logged for single lr scheduler"""
|
||||
tutils.reset_seed()
|
||||
|
||||
model = EvalModelTemplate()
|
||||
model.configure_optimizers = model.configure_optimizers__param_groups
|
||||
|
||||
lr_logger = LearningRateLogger()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=5,
|
||||
val_percent_check=0.1,
|
||||
train_percent_check=0.5,
|
||||
callbacks=[lr_logger]
|
||||
)
|
||||
results = trainer.fit(model)
|
||||
|
||||
assert lr_logger.lrs, 'No learning rates logged'
|
||||
assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
|
||||
'Number of learning rates logged does not match number of param groups'
|
||||
assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
|
||||
'Names of learning rates not set correctly'
|
||||
|
|
Loading…
Reference in New Issue