Fix checkpointed state for lr_schedulers with step interval (#7877)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent 2303f9ced8
commit d1efae2e47
@@ -260,6 +260,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

+- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
+
 - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))
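For context on what "checkpointed state" refers to here: Lightning stores each scheduler's `state_dict()` under the checkpoint's `lr_schedulers` key. A minimal sketch for inspecting it (the path "example.ckpt" is hypothetical):

import torch

# Sketch only: inspect the scheduler state Lightning writes into a checkpoint.
# The path "example.ckpt" is hypothetical.
ckpt = torch.load("example.ckpt", map_location="cpu")
scheduler_state = ckpt["lr_schedulers"][0]   # state_dict of the first configured scheduler
print(scheduler_state["_step_count"])        # number of times scheduler.step() was called
print(scheduler_state["_last_lr"])           # learning rate(s) after the last step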
@@ -219,7 +219,7 @@ class FitLoop(Loop):
         if self.training_loop.batches_seen == 0:
             return

-        self.training_loop.update_lr_schedulers('epoch')
+        self.training_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True)

         did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.skip
         if did_train_only:
@@ -115,6 +115,12 @@ class TrainingEpochLoop(Loop):
         if batch_output.signal == -1:
             raise StopIteration

+        # update non-plateau LR schedulers
+        # update epoch-interval ones only when we are at the end of training epoch
+        self.update_lr_schedulers('step', update_plateau_schedulers=False)
+        if self._num_training_batches_reached(is_last):
+            self.update_lr_schedulers('epoch', update_plateau_schedulers=False)
+
         batch_end_outputs = [opt_idx_out for opt_idx_out in batch_output.training_step_output if len(opt_idx_out)]
         processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)
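The calls added above establish the ordering this PR is about: non-plateau schedulers are stepped right after the optimizer and before validation/checkpointing, while plateau schedulers wait until their monitored metric exists. A standalone sketch of that ordering in plain PyTorch (an illustration, not the Lightning internals):

import torch

# Illustration of the ordering only; this mirrors, not reproduces, the loop internals.
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1.0)
step_sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.1)
plateau_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)

for batch in range(2):
    loss = model(torch.randn(8, 4)).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    step_sched.step()                # non-plateau: stepped before validation/checkpoint saving
    # ... validation and checkpoint saving would happen here ...
    plateau_sched.step(loss.item())  # plateau: stepped once the monitored metric is available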
@@ -153,8 +159,8 @@ class TrainingEpochLoop(Loop):
         # -----------------------------------------
         self.save_loggers_on_train_batch_end()

-        # update LR schedulers
-        self.update_lr_schedulers('step')
+        # update plateau LR scheduler after metrics are logged
+        self.update_lr_schedulers('step', update_plateau_schedulers=True)
         self.trainer.checkpoint_connector.has_trained = True

         self.total_batch_idx += 1
@@ -351,15 +357,13 @@ class TrainingEpochLoop(Loop):
             processed_outputs = processed_outputs[0]
         return processed_outputs

-    def update_lr_schedulers(self, interval: str) -> None:
+    def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
         """updates the lr schedulers based on the given interval"""
-        if interval == "step":
-            finished_accumulation = self.batch_loop._accumulated_batches_reached()
-            finished_epoch = self._num_training_batches_reached()
-            if not finished_accumulation and not finished_epoch:
-                return
+        if interval == "step" and self.batch_loop.should_accumulate():
+            return
         self.trainer.optimizer_connector.update_learning_rates(
             interval=interval,
+            update_plateau_schedulers=update_plateau_schedulers,
             opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)],
         )
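The new guard defers to `should_accumulate()`, so step-interval schedulers only advance on batches where the optimizer itself steps. A plain-Python sketch of that guard (the value of `accumulate_grad_batches` is an assumption for the example):

# Plain-Python sketch of the guard, not the Lightning code.
accumulate_grad_batches = 2
for batch_idx in range(4):
    should_accumulate = (batch_idx + 1) % accumulate_grad_batches != 0
    if should_accumulate:
        continue  # mirrors the early return: no optimizer step, so no scheduler step
    print(f"optimizer and step-interval scheduler advance after batch {batch_idx}")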
@@ -29,11 +29,17 @@ class OptimizerConnector:
         self.trainer.optimizers = []
         self.trainer.optimizer_frequencies = []

-    def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None:
+    def update_learning_rates(
+        self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
+    ) -> None:
         """Update learning rates.

         Args:
             interval: either 'epoch' or 'step'.
+            update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
+                This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
+                commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
+                so they have to be updated separately.
             opt_indices: indices of the optimizers to update.
         """
         if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
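The `update_plateau_schedulers` flag exists because of user-side configurations like the following sketch: a `ReduceLROnPlateau` entry that monitors a validation metric (the metric name "val_loss" is an assumption) and therefore cannot be stepped before validation has produced it.

import torch
from pytorch_lightning import LightningModule

class PlateauExample(LightningModule):
    # Sketch of a user-side configuration; "val_loss" is an assumed metric name.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss", "interval": "epoch"},
        }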
@@ -46,6 +52,9 @@ class OptimizerConnector:
             if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices:
                 continue

+            if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
+                continue
+
             current_idx = self.trainer.train_loop.batch_idx if interval == 'step' else self.trainer.current_epoch
             current_idx += 1  # account for both batch and epoch starts from 0
             # Take step if call to update_learning_rates matches the interval key and
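The XOR skips every scheduler whose `reduce_on_plateau` flag does not match the kind of update being requested. A minimal standalone illustration (not the Lightning code):

# Standalone illustration of the selection logic.
schedulers = [
    {"name": "step_lr", "reduce_on_plateau": False},
    {"name": "plateau", "reduce_on_plateau": True},
]

def selected(update_plateau_schedulers: bool):
    return [s["name"] for s in schedulers if not (update_plateau_schedulers ^ s["reduce_on_plateau"])]

print(selected(False))  # ['step_lr'] -> updated before validation/checkpointing
print(selected(True))   # ['plateau'] -> updated after metrics are logged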
@@ -27,7 +27,8 @@ from tests.helpers import BoringModel, RandomDataset

 class TestBackboneFinetuningCallback(BackboneFinetuning):

-    def on_train_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
+        super().on_train_epoch_start(trainer, pl_module)
         epoch = trainer.current_epoch
         if self.unfreeze_backbone_at_epoch <= epoch:
             optimizer = trainer.optimizers[0]
@@ -162,10 +162,9 @@ def test_model_checkpoint_score_and_ckpt(
         if not reduce_lr_on_plateau:
             actual_step_count = chk['lr_schedulers'][0]['_step_count']
             actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
-            # if validation_step_none, the checkpoint gets saved after the learning rate update
-            # so we need to increase the count by one
-            assert actual_step_count == epoch + 1 + validation_step_none
-            assert actual_lr == lr * gamma**(epoch + validation_step_none)
+            # checkpoint is saved after updating lr_scheduler states
+            assert actual_step_count == epoch + 2  # step_count starts at 1
+            assert actual_lr == lr * gamma**(epoch + 1)

         assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
         assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@@ -262,6 +261,11 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
         global_ix = ix + per_epoch_val_checks * epoch
         duplicated = bool(version)

+        # checkpoint saved at the end of training epoch will have updated lr_scheduler states
+        epoch_end_checkpoint = duplicated
+        if epoch_aligned:
+            epoch_end_checkpoint = ix == (per_epoch_val_checks - 1)
+
         score = model.scores[global_ix]
         expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
         expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
@@ -281,8 +285,8 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
         if not reduce_lr_on_plateau:
             actual_step_count = chk['lr_schedulers'][0]['_step_count']
             actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
-            assert actual_step_count == epoch + 1 + duplicated
-            assert actual_lr == lr * gamma**(epoch + duplicated)
+            assert actual_step_count == epoch + 1 + epoch_end_checkpoint
+            assert actual_lr == lr * gamma**(epoch + epoch_end_checkpoint)

         return score
@@ -18,6 +18,7 @@ import torch
 from torch import optim

 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel
@@ -620,3 +621,87 @@ def test_lr_scheduler_epoch_step_frequency(mocked_sched, check_val_every_n_epoch
     )
     trainer.fit(model)
     assert mocked_sched.call_count == expected_steps
+
+
+@pytest.mark.parametrize('every_n_train_steps, epoch_interval', [(None, True), (2, False), (2, True)])
+def test_lr_scheduler_state_updated_before_saving(tmpdir, every_n_train_steps, epoch_interval):
+    batches = 2
+    max_epochs = 1
+    lr, gamma = 1, 10
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        logger=False,
+        max_epochs=max_epochs,
+        limit_train_batches=batches,
+        limit_val_batches=1,
+        callbacks=[ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=every_n_train_steps)]
+    )
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=lr)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
+            lr_dict = {'scheduler': lr_scheduler}
+            if not epoch_interval:
+                lr_dict['interval'] = 'step'
+            return [optimizer], [lr_dict]
+
+        def on_save_checkpoint(self, checkpoint):
+            lr_dict = checkpoint['lr_schedulers'][0]
+            # 2 batches ran. since the lr_dict interval is `step`, the step count should be 2
+            assert self.trainer.global_step + 1 == batches  # the global step hasn't been increased yet
+            compare_to = max_epochs if epoch_interval else batches
+            assert lr_dict['_step_count'] - 1 == compare_to  # step count starts at 1
+            assert lr_dict['_last_lr'] == [lr * gamma**compare_to]
+            self.on_save_checkpoint_called = True
+
+    model = TestModel()
+    trainer.fit(model)
+    assert model.on_save_checkpoint_called
+
+
+def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir):
+    batches = 4
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        logger=False,
+        max_epochs=1,
+        limit_train_batches=batches,
+        limit_val_batches=1,
+        callbacks=[ModelCheckpoint(dirpath=tmpdir)]
+    )
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            self.log("foo", batch_idx)
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer_1 = torch.optim.Adam(self.parameters())
+            optimizer_2 = torch.optim.Adam(self.parameters())
+
+            lr_scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_1)
+            lr_dict_1 = {'scheduler': lr_scheduler1, 'interval': 'step', 'monitor': 'foo'}
+
+            lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=1)
+            lr_dict_2 = {'scheduler': lr_scheduler2, 'interval': 'step'}
+            return [optimizer_1, optimizer_2], [lr_dict_1, lr_dict_2]
+
+        def on_save_checkpoint(self, checkpoint):
+            lr_dict_1 = checkpoint['lr_schedulers'][0]
+            # since plateau schedulers are updated after saving checkpoint, last_epoch should be 3
+            assert lr_dict_1['last_epoch'] == batches - 1  # last epoch starts at 0
+
+            lr_dict_2 = checkpoint['lr_schedulers'][1]
+            assert lr_dict_2['_step_count'] - 1 == batches  # step count starts at 1
+
+            self.on_save_checkpoint_called = True
+
+    model = TestModel()
+    model.training_epoch_end = None
+    trainer.fit(model)
+    assert model.on_save_checkpoint_called
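A possible way to see the effect of the fix outside the test suite, sketched with an assumed checkpoint path: the state saved for an `interval='step'` scheduler can be loaded into a fresh scheduler, which then resumes at the learning rate the run actually reached.

import torch

# Sketch with an assumed checkpoint path; the values depend on the run that produced it.
ckpt = torch.load("example.ckpt", map_location="cpu")
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=10)
scheduler.load_state_dict(ckpt["lr_schedulers"][0])
print(scheduler.get_last_lr())  # matches the `_last_lr` values asserted in the tests above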