From d962ab5d89641e051592a99794315dfd6042eb4f Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 11 May 2020 02:35:34 +0530 Subject: [PATCH] Fix lr key name in case of param groups (#1719) * Fix lr key name in case of param groups * Add tests * Update test and added configure_optimizers__param_groups * Update CHANGELOG --- CHANGELOG.md | 2 ++ pytorch_lightning/callbacks/lr_logger.py | 4 ++-- tests/base/mixins.py | 0 tests/base/model_optimizers.py | 10 ++++++++++ tests/callbacks/test_callbacks.py | 24 ++++++++++++++++++++++++ 5 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 tests/base/mixins.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f28e258e4f..a1ae50dccc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748)) +- Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719)) + ## [0.7.5] - 2020-04-27 ### Changed diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py index 6ad68905bc..c8aab75b87 100755 --- a/pytorch_lightning/callbacks/lr_logger.py +++ b/pytorch_lightning/callbacks/lr_logger.py @@ -80,7 +80,7 @@ class LearningRateLogger(Callback): param_groups = scheduler['scheduler'].optimizer.param_groups if len(param_groups) != 1: for i, pg in enumerate(param_groups): - lr, key = pg['lr'], f'{name}/{i + 1}' + lr, key = pg['lr'], f'{name}/pg{i + 1}' self.lrs[key].append(lr) latest_stat[key] = lr else: @@ -109,7 +109,7 @@ class LearningRateLogger(Callback): param_groups = sch.optimizer.param_groups if len(param_groups) != 1: for i, pg in enumerate(param_groups): - temp = name + '/pg' + str(i + 1) + temp = f'{name}/pg{i + 1}' names.append(temp) else: names.append(name) diff --git a/tests/base/mixins.py b/tests/base/mixins.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/base/model_optimizers.py b/tests/base/model_optimizers.py index 2fd9b104a0..394ee69dae 100644 --- a/tests/base/model_optimizers.py +++ b/tests/base/model_optimizers.py @@ -59,3 +59,13 @@ class ConfigureOptimizersPool(ABC): optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) return [optimizer], [lr_scheduler] + + def configure_optimizers__param_groups(self): + param_groups = [ + {'params': list(self.parameters())[:2], 'lr': self.hparams.learning_rate * 0.1}, + {'params': list(self.parameters())[2:], 'lr': self.hparams.learning_rate} + ] + + optimizer = optim.Adam(param_groups) + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1) + return [optimizer], [lr_scheduler] diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index a4e216063a..52c03ada3b 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -331,3 +331,27 @@ def test_lr_logger_multi_lrs(tmpdir): 'Number of learning rates logged does not match number of lr schedulers' assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \ 'Names of learning rates not set correctly' + + +def test_lr_logger_param_groups(tmpdir): + """ Test that learning rates are extracted and logged for single lr scheduler""" + tutils.reset_seed() + + model = EvalModelTemplate() + model.configure_optimizers = model.configure_optimizers__param_groups + + lr_logger = LearningRateLogger() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=5, + val_percent_check=0.1, + train_percent_check=0.5, + callbacks=[lr_logger] + ) + results = trainer.fit(model) + + assert lr_logger.lrs, 'No learning rates logged' + assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \ + 'Number of learning rates logged does not match number of param groups' + assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \ + 'Names of learning rates not set correctly'