diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ab690429b..4bc697f5e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251)) + - Used `fsspec` instead of `gfile` for all IO ([#3320](https://github.com/PyTorchLightning/pytorch-lightning/pull/3320)) ### Deprecated diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index c3a85887b8..a73a70fb79 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -120,7 +120,7 @@ Lightning has a few built-in callbacks. ---------------- -.. automodule:: pytorch_lightning.callbacks.lr_logger +.. automodule:: pytorch_lightning.callbacks.lr_monitor :noindex: :exclude-members: _extract_lr, diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index eab698d06d..039f4077ef 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -1,18 +1,20 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler from pytorch_lightning.callbacks.lr_logger import LearningRateLogger +from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar -from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor +from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase __all__ = [ 'Callback', 'EarlyStopping', - 'ModelCheckpoint', + 'GPUStatsMonitor', 'GradientAccumulationScheduler', 'LearningRateLogger', - 'ProgressBarBase', + 'LearningRateMonitor', + 'ModelCheckpoint', 'ProgressBar', - 'GPUStatsMonitor' + 'ProgressBarBase', ] diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py old mode 100755 new mode 100644 index 4209f8e8b6..76ade47087 --- a/pytorch_lightning/callbacks/lr_logger.py +++ b/pytorch_lightning/callbacks/lr_logger.py @@ -1,156 +1,9 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r""" - -Learning Rate Logger -==================== - -Log learning rate for lr schedulers during training - -""" - -from typing import Optional - -from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.exceptions import MisconfigurationException -class LearningRateLogger(Callback): - r""" - Automatically logs learning rate for learning rate schedulers during training. - - Args: - logging_interval: set to `epoch` or `step` to log `lr` of all optimizers - at the same interval, set to `None` to log at individual interval - according to the `interval` key of each scheduler. Defaults to ``None``. - - Example:: - - >>> from pytorch_lightning import Trainer - >>> from pytorch_lightning.callbacks import LearningRateLogger - >>> lr_logger = LearningRateLogger(logging_interval='step') - >>> trainer = Trainer(callbacks=[lr_logger]) - - Logging names are automatically determined based on optimizer class name. - In case of multiple optimizers of same type, they will be named `Adam`, - `Adam-1` etc. If a optimizer has multiple parameter groups they will - be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a - `name` keyword in the construction of the learning rate schdulers - - Example:: - - def configure_optimizer(self): - optimizer = torch.optim.Adam(...) - lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...) - 'name': 'my_logging_name'} - return [optimizer], [lr_scheduler] - """ - def __init__(self, logging_interval: Optional[str] = None): - if logging_interval not in (None, 'step', 'epoch'): - raise MisconfigurationException( - 'logging_interval should be `step` or `epoch` or `None`.' - ) - - self.logging_interval = logging_interval - self.lrs = None - self.lr_sch_names = [] - - def on_train_start(self, trainer, pl_module): - """ Called before training, determines unique names for all lr - schedulers in the case of multiple of the same type or in - the case of multiple parameter groups - """ - if not trainer.logger: - raise MisconfigurationException( - 'Cannot use LearningRateLogger callback with Trainer that has no logger.' - ) - - if not trainer.lr_schedulers: - rank_zero_warn( - 'You are using LearningRateLogger callback with models that' - ' have no learning rate schedulers. Please see documentation' - ' for `configure_optimizers` method.', RuntimeWarning - ) - - # Find names for schedulers - names = self._find_names(trainer.lr_schedulers) - - # Initialize for storing values - self.lrs = {name: [] for name in names} - - def on_batch_start(self, trainer, pl_module): - if self.logging_interval != 'epoch': - interval = 'step' if self.logging_interval is None else 'any' - latest_stat = self._extract_lr(trainer, interval) - - if trainer.logger is not None and latest_stat: - trainer.logger.log_metrics(latest_stat, step=trainer.global_step) - - def on_epoch_start(self, trainer, pl_module): - if self.logging_interval != 'step': - interval = 'epoch' if self.logging_interval is None else 'any' - latest_stat = self._extract_lr(trainer, interval) - - if trainer.logger is not None and latest_stat: - trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch) - - def _extract_lr(self, trainer, interval): - latest_stat = {} - - for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): - if scheduler['interval'] == interval or interval == 'any': - param_groups = scheduler['scheduler'].optimizer.param_groups - if len(param_groups) != 1: - for i, pg in enumerate(param_groups): - lr, key = pg['lr'], f'{name}/pg{i + 1}' - self.lrs[key].append(lr) - latest_stat[key] = lr - else: - self.lrs[name].append(param_groups[0]['lr']) - latest_stat[name] = param_groups[0]['lr'] - - return latest_stat - - def _find_names(self, lr_schedulers): - # Create uniqe names in the case we have multiple of the same learning - # rate schduler + multiple parameter groups - names = [] - for scheduler in lr_schedulers: - sch = scheduler['scheduler'] - if 'name' in scheduler: - name = scheduler['name'] - else: - opt_name = 'lr-' + sch.optimizer.__class__.__name__ - i, name = 1, opt_name - # Multiple schduler of the same type - while True: - if name not in names: - break - i, name = i + 1, f'{opt_name}-{i}' - - # Multiple param groups for the same schduler - param_groups = sch.optimizer.param_groups - - if len(param_groups) != 1: - for i, pg in enumerate(param_groups): - temp = f'{name}/pg{i + 1}' - names.append(temp) - else: - names.append(name) - - self.lr_sch_names.append(name) - - return names +class LearningRateLogger(LearningRateMonitor): + def __init__(self, *args, **kwargs): + rank_zero_warn("`LearningRateLogger` is now `LearningRateMonitor`" + " and this will be removed in v0.11.0", DeprecationWarning) + super().__init__(*args, **kwargs) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py new file mode 100755 index 0000000000..da87e6d5a5 --- /dev/null +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -0,0 +1,159 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" + +Learning Rate Monitor +===================== + +Monitor and logs learning rate for lr schedulers during training. + +""" + +from typing import Optional + +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class LearningRateMonitor(Callback): + r""" + Automatically monitor and logs learning rate for learning rate schedulers during training. + + Args: + logging_interval: set to `epoch` or `step` to log `lr` of all optimizers + at the same interval, set to `None` to log at individual interval + according to the `interval` key of each scheduler. Defaults to ``None``. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import LearningRateMonitor + >>> lr_monitor = LearningRateMonitor(logging_interval='step') + >>> trainer = Trainer(callbacks=[lr_monitor]) + + Logging names are automatically determined based on optimizer class name. + In case of multiple optimizers of same type, they will be named `Adam`, + `Adam-1` etc. If a optimizer has multiple parameter groups they will + be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a + `name` keyword in the construction of the learning rate schdulers + + Example:: + + def configure_optimizer(self): + optimizer = torch.optim.Adam(...) + lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...) + 'name': 'my_logging_name'} + return [optimizer], [lr_scheduler] + + """ + def __init__(self, logging_interval: Optional[str] = None): + if logging_interval not in (None, 'step', 'epoch'): + raise MisconfigurationException( + 'logging_interval should be `step` or `epoch` or `None`.' + ) + + self.logging_interval = logging_interval + self.lrs = None + self.lr_sch_names = [] + + def on_train_start(self, trainer, pl_module): + """ + Called before training, determines unique names for all lr + schedulers in the case of multiple of the same type or in + the case of multiple parameter groups + """ + if not trainer.logger: + raise MisconfigurationException( + 'Cannot use LearningRateMonitor callback with Trainer that has no logger.' + ) + + if not trainer.lr_schedulers: + rank_zero_warn( + 'You are using LearningRateMonitor callback with models that' + ' have no learning rate schedulers. Please see documentation' + ' for `configure_optimizers` method.', RuntimeWarning + ) + + # Find names for schedulers + names = self._find_names(trainer.lr_schedulers) + + # Initialize for storing values + self.lrs = {name: [] for name in names} + + def on_batch_start(self, trainer, pl_module): + if self.logging_interval != 'epoch': + interval = 'step' if self.logging_interval is None else 'any' + latest_stat = self._extract_lr(trainer, interval) + + if trainer.logger is not None and latest_stat: + trainer.logger.log_metrics(latest_stat, step=trainer.global_step) + + def on_epoch_start(self, trainer, pl_module): + if self.logging_interval != 'step': + interval = 'epoch' if self.logging_interval is None else 'any' + latest_stat = self._extract_lr(trainer, interval) + + if trainer.logger is not None and latest_stat: + trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch) + + def _extract_lr(self, trainer, interval): + latest_stat = {} + + for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): + if scheduler['interval'] == interval or interval == 'any': + param_groups = scheduler['scheduler'].optimizer.param_groups + if len(param_groups) != 1: + for i, pg in enumerate(param_groups): + lr, key = pg['lr'], f'{name}/pg{i + 1}' + self.lrs[key].append(lr) + latest_stat[key] = lr + else: + self.lrs[name].append(param_groups[0]['lr']) + latest_stat[name] = param_groups[0]['lr'] + + return latest_stat + + def _find_names(self, lr_schedulers): + # Create uniqe names in the case we have multiple of the same learning + # rate schduler + multiple parameter groups + names = [] + for scheduler in lr_schedulers: + sch = scheduler['scheduler'] + if 'name' in scheduler: + name = scheduler['name'] + else: + opt_name = 'lr-' + sch.optimizer.__class__.__name__ + i, name = 1, opt_name + + # Multiple schduler of the same type + while True: + if name not in names: + break + i, name = i + 1, f'{opt_name}-{i}' + + # Multiple param groups for the same schduler + param_groups = sch.optimizer.param_groups + + if len(param_groups) != 1: + for i, pg in enumerate(param_groups): + temp = f'{name}/pg{i + 1}' + names.append(temp) + else: + names.append(name) + + self.lr_sch_names.append(name) + + return names diff --git a/tests/callbacks/test_lr_logger.py b/tests/callbacks/test_lr_monitor.py similarity index 58% rename from tests/callbacks/test_lr_logger.py rename to tests/callbacks/test_lr_monitor.py index 264329fa3f..4370150768 100644 --- a/tests/callbacks/test_lr_logger.py +++ b/tests/callbacks/test_lr_monitor.py @@ -1,78 +1,96 @@ import pytest -import tests.base.develop_utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import LearningRateLogger +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +import tests.base.develop_utils as tutils -def test_lr_logger_single_lr(tmpdir): +def test_lr_monitor_single_lr(tmpdir): """ Test that learning rates are extracted and logged for single lr scheduler. """ tutils.reset_seed() model = EvalModelTemplate() model.configure_optimizers = model.configure_optimizers__single_scheduler - lr_logger = LearningRateLogger() + lr_monitor = LearningRateMonitor() trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, limit_val_batches=0.1, limit_train_batches=0.5, - callbacks=[lr_logger], + callbacks=[lr_monitor], ) result = trainer.fit(model) assert result - assert lr_logger.lrs, 'No learning rates logged' - assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \ + assert lr_monitor.lrs, 'No learning rates logged' + assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \ 'Number of learning rates logged does not match number of lr schedulers' - assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \ + assert all([k in ['lr-Adam'] for k in lr_monitor.lrs.keys()]), \ 'Names of learning rates not set correctly' -def test_lr_logger_no_lr(tmpdir): +def test_lr_monitor_no_lr_scheduler(tmpdir): tutils.reset_seed() model = EvalModelTemplate() - lr_logger = LearningRateLogger() + lr_monitor = LearningRateMonitor() trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, limit_val_batches=0.1, limit_train_batches=0.5, - callbacks=[lr_logger], + callbacks=[lr_monitor], ) - with pytest.warns(RuntimeWarning): + with pytest.warns(RuntimeWarning, match='have no learning rate schedulers'): result = trainer.fit(model) assert result +def test_lr_monitor_no_logger(tmpdir): + tutils.reset_seed() + + model = EvalModelTemplate() + + lr_monitor = LearningRateMonitor() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + callbacks=[lr_monitor], + logger=False + ) + + with pytest.raises(MisconfigurationException, match='Trainer that has no logger'): + trainer.fit(model) + + @pytest.mark.parametrize("logging_interval", ['step', 'epoch']) -def test_lr_logger_multi_lrs(tmpdir, logging_interval): +def test_lr_monitor_multi_lrs(tmpdir, logging_interval): """ Test that learning rates are extracted and logged for multi lr schedulers. """ tutils.reset_seed() model = EvalModelTemplate() model.configure_optimizers = model.configure_optimizers__multiple_schedulers - lr_logger = LearningRateLogger(logging_interval=logging_interval) + lr_monitor = LearningRateMonitor(logging_interval=logging_interval) trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, limit_val_batches=0.1, limit_train_batches=0.5, - callbacks=[lr_logger], + callbacks=[lr_monitor], ) result = trainer.fit(model) assert result - assert lr_logger.lrs, 'No learning rates logged' - assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \ + assert lr_monitor.lrs, 'No learning rates logged' + assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \ 'Number of learning rates logged does not match number of lr schedulers' - assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \ + assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_monitor.lrs.keys()]), \ 'Names of learning rates not set correctly' if logging_interval == 'step': @@ -80,30 +98,30 @@ def test_lr_logger_multi_lrs(tmpdir, logging_interval): if logging_interval == 'epoch': expected_number_logged = trainer.max_epochs - assert all(len(lr) == expected_number_logged for lr in lr_logger.lrs.values()), \ + assert all(len(lr) == expected_number_logged for lr in lr_monitor.lrs.values()), \ 'Length of logged learning rates do not match the expected number' -def test_lr_logger_param_groups(tmpdir): +def test_lr_monitor_param_groups(tmpdir): """ Test that learning rates are extracted and logged for single lr scheduler. """ tutils.reset_seed() model = EvalModelTemplate() model.configure_optimizers = model.configure_optimizers__param_groups - lr_logger = LearningRateLogger() + lr_monitor = LearningRateMonitor() trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, limit_val_batches=0.1, limit_train_batches=0.5, - callbacks=[lr_logger], + callbacks=[lr_monitor], ) result = trainer.fit(model) assert result - assert lr_logger.lrs, 'No learning rates logged' - assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \ + assert lr_monitor.lrs, 'No learning rates logged' + assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers), \ 'Number of learning rates logged does not match number of param groups' - assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \ + assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_monitor.lrs.keys()]), \ 'Names of learning rates not set correctly' diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index a34b3b3d72..a3ce174713 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -6,6 +6,7 @@ import pytest import torch from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateLogger from tests.base import EvalModelTemplate @@ -15,6 +16,11 @@ def _soft_unimport_module(str_module): del sys.modules[str_module] +def test_tbd_remove_in_v0_11_0_trainer(): + with pytest.deprecated_call(match='will be removed in v0.11.0'): + lr_logger = LearningRateLogger() + + def test_tbd_remove_in_v0_10_0_trainer(): rnd_val = random.random() with pytest.deprecated_call(match='will be removed in v0.10.0'):