LearningRateLogger in multi-scheduler setting (#1944)

* fixed undesired behaviour due to dict.fromkeys

* a test for log length consistency

* runtime-warn if no schedulers are configured

* chlog

* move

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Ivan Nazarov 2020-05-28 05:44:46 +03:00 committed by GitHub
parent 3af4994d5a
commit 7c19c373ac
4 changed files with 118 additions and 84 deletions

CHANGELOG.md

@@ -42,7 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))
-- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
+- Fixed bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
+- Fixed `LearningRateLogger` in multi-scheduler setting ([#1944](https://github.com/PyTorchLightning/pytorch-lightning/pull/1944))
 
 ## [0.7.6] - 2020-05-16
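For context, the fixed callback is attached like any other Trainer callback. The sketch below is illustrative only: `LitModel` is a placeholder for a LightningModule whose `configure_optimizers` returns two optimizers, each with its own scheduler, and is not part of this diff.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateLogger

lr_logger = LearningRateLogger()
trainer = Trainer(max_epochs=5, callbacks=[lr_logger])

# `LitModel` is a hypothetical LightningModule, not part of this PR.
trainer.fit(LitModel())

# With this fix each scheduler keeps an independent history, e.g.
# {'lr-Adam': [0.001, ...], 'lr-Adam-1': [0.01, ...]}
print(lr_logger.lrs)
```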

pytorch_lightning/callbacks/lr_logger.py

@@ -10,6 +10,8 @@ Log learning rate for lr schedulers during training
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 class LearningRateLogger(Callback):
     r"""
@@ -45,21 +47,22 @@ class LearningRateLogger(Callback):
             schedulers in the case of multiple of the same type or in
             the case of multiple parameter groups
         """
-        if trainer.lr_schedulers == []:
-            raise MisconfigurationException(
-                'Cannot use LearningRateLogger callback with models that have no'
-                ' learning rate schedulers. Please see documentation for'
-                ' `configure_optimizers` method.')
-
         if not trainer.logger:
             raise MisconfigurationException(
                 'Cannot use LearningRateLogger callback with Trainer that has no logger.')
 
+        if not trainer.lr_schedulers:
+            rank_zero_warn(
+                'You are using LearningRateLogger callback with models that'
+                ' have no learning rate schedulers. Please see documentation'
+                ' for `configure_optimizers` method.', RuntimeWarning
+            )
+
         # Find names for schedulers
         names = self._find_names(trainer.lr_schedulers)
 
         # Initialize for storing values
-        self.lrs = dict.fromkeys(names, [])
+        self.lrs = {name: [] for name in names}
 
     def on_batch_start(self, trainer, pl_module):
         latest_stat = self._extract_lr(trainer, 'step')
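The one-line change to `self.lrs` is the core of the fix: `dict.fromkeys(names, [])` evaluates the empty list once and binds every key to that same object, so values logged for one scheduler showed up under all of them. A minimal standalone sketch of the difference:

```python
names = ['lr-Adam', 'lr-Adam-1']

# Buggy: dict.fromkeys binds both keys to the *same* list object.
shared = dict.fromkeys(names, [])
shared['lr-Adam'].append(0.001)
print(shared)    # {'lr-Adam': [0.001], 'lr-Adam-1': [0.001]}

# Fixed: the comprehension creates a fresh list per key.
separate = {name: [] for name in names}
separate['lr-Adam'].append(0.001)
print(separate)  # {'lr-Adam': [0.001], 'lr-Adam-1': []}
```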

tests/callbacks/test_callbacks.py

@@ -1,12 +1,13 @@
+from pathlib import Path
+
 import pytest
 
 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
-from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
-from pathlib import Path
 
 
 def test_trainer_callback_system(tmpdir):
@@ -281,77 +282,3 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
 
     ckpt_version = Path(trainer.ckpt_path).parent.name
     assert ckpt_version == expected
-
-
-def test_lr_logger_single_lr(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__single_scheduler
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-    assert results == 1
-
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-
-
-def test_lr_logger_multi_lrs(tmpdir):
-    """ Test that learning rates are extracted and logged for multi lr schedulers """
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-    assert results == 1
-
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-
-
-def test_lr_logger_param_groups(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__param_groups
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of param groups'
-    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'

tests/callbacks/test_lr.py (new file, 102 lines)

@@ -0,0 +1,102 @@
+import pytest
+
+import tests.base.utils as tutils
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import LearningRateLogger
+from tests.base import EvalModelTemplate
+
+
+def test_lr_logger_single_lr(tmpdir):
+    """ Test that learning rates are extracted and logged for single lr scheduler. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__single_scheduler
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+
+
+def test_lr_logger_no_lr(tmpdir):
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+
+    with pytest.warns(RuntimeWarning):
+        result = trainer.fit(model)
+        assert result
+
+
+def test_lr_logger_multi_lrs(tmpdir):
+    """ Test that learning rates are extracted and logged for multi lr schedulers. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=10,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+    assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
+        'Length of logged learning rates exceeds the number of epochs'
+
+
+def test_lr_logger_param_groups(tmpdir):
""" Test that learning rates are extracted and logged for single lr scheduler. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__param_groups
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of param groups'
+    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
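The `configure_optimizers__*` fixtures referenced above live in `EvalModelTemplate`. As a rough sketch of the shapes involved (not the actual template code), variants along these lines would produce the key names asserted in the tests:

```python
import torch
from torch import nn, optim

# Dummy parameters so the optimizers have something to hold.
params = [nn.Parameter(torch.zeros(1)) for _ in range(4)]

# Two optimizers, one scheduler each -> logged keys 'lr-Adam' and 'lr-Adam-1'.
def configure_optimizers__multiple_schedulers():
    opt1 = optim.Adam(params[:2], lr=1e-3)
    opt2 = optim.Adam(params[2:], lr=1e-2)
    return [opt1, opt2], [optim.lr_scheduler.StepLR(opt1, step_size=1),
                          optim.lr_scheduler.StepLR(opt2, step_size=1)]

# One optimizer with two parameter groups -> logged keys 'lr-Adam/pg1' and 'lr-Adam/pg2'.
def configure_optimizers__param_groups():
    opt = optim.Adam([
        {'params': params[:2], 'lr': 1e-3},
        {'params': params[2:], 'lr': 1e-2},
    ])
    return [opt], [optim.lr_scheduler.StepLR(opt, step_size=1)]
```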