LearningRateLogger in multi-scheduler setting (#1944)
* fixed undesired behaviour due to dict.fromkeys
* a test for log length consistency
* runtime-warn if no schedulers are configured
* chlog
* move

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
parent 3af4994d5a
commit 7c19c373ac
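Background for the first commit bullet: `dict.fromkeys(names, [])` binds every key to the same list object, so every scheduler's learning rates were appended into one shared list. A minimal standalone sketch of the difference (plain Python, illustrative values; the scheduler names are the ones asserted in the tests below):

names = ['lr-Adam', 'lr-Adam-1']

# dict.fromkeys reuses the SAME list object for every key ...
shared = dict.fromkeys(names, [])
shared['lr-Adam'].append(0.01)
assert shared['lr-Adam-1'] == [0.01]   # unintended: the other scheduler "logged" a value too

# ... while a dict comprehension creates a fresh list per key.
separate = {name: [] for name in names}
separate['lr-Adam'].append(0.01)
assert separate['lr-Adam-1'] == []     # entries stay independent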
@@ -42,7 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))
 
-- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
+- Fixed bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
+
+- Fixed `LearningRateLogger` in multi-scheduler setting ([#1944](https://github.com/PyTorchLightning/pytorch-lightning/pull/1944))
 
 ## [0.7.6] - 2020-05-16
@@ -10,6 +10,8 @@ Log learning rate for lr schedulers during training
 
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 class LearningRateLogger(Callback):
     r"""
@@ -45,21 +47,22 @@ class LearningRateLogger(Callback):
             schedulers in the case of multiple of the same type or in
             the case of multiple parameter groups
         """
-        if trainer.lr_schedulers == []:
-            raise MisconfigurationException(
-                'Cannot use LearningRateLogger callback with models that have no'
-                ' learning rate schedulers. Please see documentation for'
-                ' `configure_optimizers` method.')
-
         if not trainer.logger:
             raise MisconfigurationException(
                 'Cannot use LearningRateLogger callback with Trainer that has no logger.')
 
+        if not trainer.lr_schedulers:
+            rank_zero_warn(
+                'You are using LearningRateLogger callback with models that'
+                ' have no learning rate schedulers. Please see documentation'
+                ' for `configure_optimizers` method.', RuntimeWarning
+            )
+
         # Find names for schedulers
         names = self._find_names(trainer.lr_schedulers)
 
         # Initialize for storing values
-        self.lrs = dict.fromkeys(names, [])
+        self.lrs = {name: [] for name in names}
 
     def on_batch_start(self, trainer, pl_module):
         latest_stat = self._extract_lr(trainer, 'step')
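For context on the `lr-Adam` / `lr-Adam-1` names asserted in the tests below: they arise when `configure_optimizers` returns two optimizers of the same class, each paired with its own scheduler, and the callback appends an index to keep the logged keys unique. A rough sketch of such a setup (hypothetical toy model, not part of this commit; the tests themselves use the `EvalModelTemplate.configure_optimizers__multiple_schedulers` hook):

import torch
from torch import nn


class ToyModel(nn.Module):
    """Stand-in for a LightningModule's optimizer/scheduler wiring."""

    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(4, 4)
        self.branch_b = nn.Linear(4, 4)

    def configure_optimizers(self):
        # Two Adam optimizers, each with its own scheduler.  Because the
        # optimizer class name repeats, LearningRateLogger disambiguates the
        # logged keys as 'lr-Adam' and 'lr-Adam-1'.
        opt_a = torch.optim.Adam(self.branch_a.parameters(), lr=1e-3)
        opt_b = torch.optim.Adam(self.branch_b.parameters(), lr=1e-2)
        sched_a = torch.optim.lr_scheduler.StepLR(opt_a, step_size=1)
        sched_b = torch.optim.lr_scheduler.StepLR(opt_b, step_size=1)
        return [opt_a, opt_b], [sched_a, sched_b]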
@@ -1,12 +1,13 @@
+from pathlib import Path
+
 import pytest
 
 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
-from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
-from pathlib import Path
 
 
 def test_trainer_callback_system(tmpdir):
@@ -281,77 +282,3 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
 
     ckpt_version = Path(trainer.ckpt_path).parent.name
     assert ckpt_version == expected
-
-
-def test_lr_logger_single_lr(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__single_scheduler
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert results == 1
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-
-
-def test_lr_logger_multi_lrs(tmpdir):
-    """ Test that learning rates are extracted and logged for multi lr schedulers """
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert results == 1
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-
-
-def test_lr_logger_param_groups(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__param_groups
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of param groups'
-    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
@@ -0,0 +1,102 @@
+import pytest
+
+import tests.base.utils as tutils
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import LearningRateLogger
+from tests.base import EvalModelTemplate
+
+
+def test_lr_logger_single_lr(tmpdir):
+    """ Test that learning rates are extracted and logged for single lr scheduler. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__single_scheduler
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+
+
+def test_lr_logger_no_lr(tmpdir):
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+
+    with pytest.warns(RuntimeWarning):
+        result = trainer.fit(model)
+        assert result
+
+
+def test_lr_logger_multi_lrs(tmpdir):
+    """ Test that learning rates are extracted and logged for multi lr schedulers. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=10,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+    assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
+        'Length of logged learning rates exceeds the number of epochs'
+
+
+def test_lr_logger_param_groups(tmpdir):
+    """ Test that learning rates are extracted and logged for single lr scheduler. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__param_groups
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of param groups'
+    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
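Similarly, the `lr-Adam/pg1` / `lr-Adam/pg2` names in `test_lr_logger_param_groups` come from a single optimizer with two parameter groups; the callback then logs one rate per group, which is why that test expects `2 * len(trainer.lr_schedulers)` entries. A rough sketch of that kind of setup (hypothetical code, not taken from the repository; the test itself relies on `EvalModelTemplate.configure_optimizers__param_groups`):

import torch
from torch import nn


class ToyParamGroupModel(nn.Module):
    """Stand-in for a LightningModule that uses per-group learning rates."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)

    def configure_optimizers(self):
        # One Adam optimizer, two parameter groups with different lrs.
        # LearningRateLogger names each group under the optimizer, e.g.
        # 'lr-Adam/pg1' and 'lr-Adam/pg2', so two rates are logged per scheduler.
        optimizer = torch.optim.Adam([
            {'params': self.backbone.parameters(), 'lr': 1e-4},
            {'params': self.head.parameters(), 'lr': 1e-3},
        ])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]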
|
Loading…
Reference in New Issue