diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6b5ddef3e0..3f1d85b777 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -260,6 +260,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
+
+
 - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 756e48d1be..5ccfb2cb1d 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -219,7 +219,7 @@ class FitLoop(Loop):
         if self.training_loop.batches_seen == 0:
             return

-        self.training_loop.update_lr_schedulers('epoch')
+        self.training_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True)

         did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.skip
         if did_train_only:
diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py
index ed0af261da..2938a3d759 100644
--- a/pytorch_lightning/loops/training_epoch_loop.py
+++ b/pytorch_lightning/loops/training_epoch_loop.py
@@ -115,6 +115,12 @@ class TrainingEpochLoop(Loop):
         if batch_output.signal == -1:
             raise StopIteration

+        # update non-plateau LR schedulers
+        # update epoch-interval ones only when we are at the end of the training epoch
+        self.update_lr_schedulers('step', update_plateau_schedulers=False)
+        if self._num_training_batches_reached(is_last):
+            self.update_lr_schedulers('epoch', update_plateau_schedulers=False)
+
         batch_end_outputs = [opt_idx_out for opt_idx_out in batch_output.training_step_output if len(opt_idx_out)]
         processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)

@@ -153,8 +159,8 @@ class TrainingEpochLoop(Loop):
         # -----------------------------------------
         self.save_loggers_on_train_batch_end()

-        # update LR schedulers
-        self.update_lr_schedulers('step')
+        # update plateau LR schedulers after metrics are logged
+        self.update_lr_schedulers('step', update_plateau_schedulers=True)

         self.trainer.checkpoint_connector.has_trained = True
         self.total_batch_idx += 1
@@ -351,15 +357,13 @@ class TrainingEpochLoop(Loop):
             processed_outputs = processed_outputs[0]
         return processed_outputs

-    def update_lr_schedulers(self, interval: str) -> None:
+    def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
         """updates the lr schedulers based on the given interval"""
-        if interval == "step":
-            finished_accumulation = self.batch_loop._accumulated_batches_reached()
-            finished_epoch = self._num_training_batches_reached()
-            if not finished_accumulation and not finished_epoch:
-                return
+        if interval == "step" and self.batch_loop.should_accumulate():
+            return
         self.trainer.optimizer_connector.update_learning_rates(
             interval=interval,
+            update_plateau_schedulers=update_plateau_schedulers,
             opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)],
         )

diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py
index 2797504288..be13056bb1 100644
--- a/pytorch_lightning/trainer/connectors/optimizer_connector.py
+++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -29,11 +29,17 @@ class OptimizerConnector:
         self.trainer.optimizers = []
         self.trainer.optimizer_frequencies = []

-    def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None:
+    def update_learning_rates(
+        self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
+    ) -> None:
         """Update learning rates.

         Args:
             interval: either 'epoch' or 'step'.
+            update_plateau_schedulers: controls whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
+                This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
+                commonly saved during validation; however, plateau schedulers might monitor a validation metric,
+                so they have to be updated separately.
             opt_indices: indices of the optimizers to update.
         """
         if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
@@ -46,6 +52,9 @@ class OptimizerConnector:
             if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices:
                 continue

+            if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
+                continue
+
             current_idx = self.trainer.train_loop.batch_idx if interval == 'step' else self.trainer.current_epoch
             current_idx += 1  # account for both batch and epoch starts from 0
             # Take step if call to update_learning_rates matches the interval key and
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index ef421e9219..fe4cd5cce1 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -27,7 +27,8 @@ from tests.helpers import BoringModel, RandomDataset

 class TestBackboneFinetuningCallback(BackboneFinetuning):

-    def on_train_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
+        super().on_train_epoch_start(trainer, pl_module)
         epoch = trainer.current_epoch
         if self.unfreeze_backbone_at_epoch <= epoch:
             optimizer = trainer.optimizers[0]
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 62b9d8364b..4b1733c135 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -162,10 +162,9 @@ def test_model_checkpoint_score_and_ckpt(
         if not reduce_lr_on_plateau:
             actual_step_count = chk['lr_schedulers'][0]['_step_count']
             actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
-            # if validation_step_none, the checkpoint gets saved after the learning rate update
-            # so we need to increase the count by one
-            assert actual_step_count == epoch + 1 + validation_step_none
-            assert actual_lr == lr * gamma**(epoch + validation_step_none)
+            # the checkpoint is saved after updating the lr_scheduler states
+            assert actual_step_count == epoch + 2  # step_count starts at 1
+            assert actual_lr == lr * gamma**(epoch + 1)

         assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
         assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@@ -262,6 +261,11 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
         global_ix = ix + per_epoch_val_checks * epoch
         duplicated = bool(version)

+        # a checkpoint saved at the end of a training epoch will have updated lr_scheduler states
+        epoch_end_checkpoint = duplicated
+        if epoch_aligned:
+            epoch_end_checkpoint = ix == (per_epoch_val_checks - 1)
+
         score = model.scores[global_ix]
         expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
         expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
@@ -281,8 +285,8 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
         if not reduce_lr_on_plateau:
             actual_step_count = chk['lr_schedulers'][0]['_step_count']
             actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
-            assert actual_step_count == epoch + 1 + duplicated
-            assert actual_lr == lr * gamma**(epoch + duplicated)
+            assert actual_step_count == epoch + 1 + epoch_end_checkpoint
+            assert actual_lr == lr * gamma**(epoch + epoch_end_checkpoint)

         return score
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index a81e0eecf5..6165aa1321 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -18,6 +18,7 @@ import torch
 from torch import optim

 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel
@@ -620,3 +621,87 @@ def test_lr_scheduler_epoch_step_frequency(mocked_sched, check_val_every_n_epoch
     )
     trainer.fit(model)
     assert mocked_sched.call_count == expected_steps
+
+
+@pytest.mark.parametrize('every_n_train_steps, epoch_interval', [(None, True), (2, False), (2, True)])
+def test_lr_scheduler_state_updated_before_saving(tmpdir, every_n_train_steps, epoch_interval):
+    batches = 2
+    max_epochs = 1
+    lr, gamma = 1, 10
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        logger=False,
+        max_epochs=max_epochs,
+        limit_train_batches=batches,
+        limit_val_batches=1,
+        callbacks=[ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=every_n_train_steps)]
+    )
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=lr)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
+            lr_dict = {'scheduler': lr_scheduler}
+            if not epoch_interval:
+                lr_dict['interval'] = 'step'
+            return [optimizer], [lr_dict]
+
+        def on_save_checkpoint(self, checkpoint):
+            lr_dict = checkpoint['lr_schedulers'][0]
+            # 2 batches ran; with a `step` interval, the step count should be 2
+            assert self.trainer.global_step + 1 == batches  # the global step hasn't been increased yet
+            compare_to = max_epochs if epoch_interval else batches
+            assert lr_dict['_step_count'] - 1 == compare_to  # step count starts at 1
+            assert lr_dict['_last_lr'] == [lr * gamma**compare_to]
+            self.on_save_checkpoint_called = True
+
+    model = TestModel()
+    trainer.fit(model)
+    assert model.on_save_checkpoint_called
+
+
+def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir):
+    batches = 4
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        logger=False,
+        max_epochs=1,
+        limit_train_batches=batches,
+        limit_val_batches=1,
+        callbacks=[ModelCheckpoint(dirpath=tmpdir)]
+    )
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            self.log("foo", batch_idx)
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer_1 = torch.optim.Adam(self.parameters())
+            optimizer_2 = torch.optim.Adam(self.parameters())
+
+            lr_scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_1)
+            lr_dict_1 = {'scheduler': lr_scheduler1, 'interval': 'step', 'monitor': 'foo'}
+
+            lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=1)
+            lr_dict_2 = {'scheduler': lr_scheduler2, 'interval': 'step'}
+            return [optimizer_1, optimizer_2], [lr_dict_1, lr_dict_2]
+
+        def on_save_checkpoint(self, checkpoint):
+            lr_dict_1 = checkpoint['lr_schedulers'][0]
+            # since plateau schedulers are updated after saving the checkpoint, last_epoch should be 3
+            assert lr_dict_1['last_epoch'] == batches - 1  # last_epoch starts at 0
+
+            lr_dict_2 = checkpoint['lr_schedulers'][1]
+            assert lr_dict_2['_step_count'] - 1 == batches  # step count starts at 1
+
+            self.on_save_checkpoint_called = True
+
+    model = TestModel()
+    model.training_epoch_end = None
+    trainer.fit(model)
+    assert model.on_save_checkpoint_called
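
Context for the change (not part of the patch): non-plateau schedulers now step *before* validation runs, so any checkpoint saved during or after validation already contains their updated state, while `ReduceLROnPlateau` schedulers still step *after* the metrics they monitor have been logged. The `update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]` check in the connector simply skips every scheduler whose plateau flag doesn't match the phase being updated. Below is a minimal user-level sketch of the fixed behavior, reusing `BoringModel` from the tests above; the `StepLRModel` name and the `checkpoints/` path are illustrative, and the expected numbers follow the assertions in `test_lr_scheduler_state_updated_before_saving`:

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers.boring_model import BoringModel


class StepLRModel(BoringModel):
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=1.0)
        # steps once per epoch, multiplying the LR by 10
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=10)
        return [optimizer], [scheduler]


trainer = Trainer(
    max_epochs=1,
    limit_train_batches=2,
    limit_val_batches=1,
    callbacks=[ModelCheckpoint(dirpath='checkpoints/')],
)
trainer.fit(StepLRModel())

ckpt = torch.load(trainer.checkpoint_callback.best_model_path)
state = ckpt['lr_schedulers'][0]
# with this fix, the saved state already includes the epoch-end scheduler step:
# `_step_count` is 2 (it starts at 1) and `_last_lr` is [10.0]
assert state['_step_count'] == 2
assert state['_last_lr'] == [10.0]
```

Resuming from such a checkpoint therefore continues with the learning rate the scheduler had actually reached, instead of replaying one scheduler step behind.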