[Bugfix] Fixed epoch level schedulers not being called when val_check_interval < 1.0 (#6075)

* fix bug

* fix tests

* changelog

* fix pep8

* fix tests

* fix and add some tests

* add test for rlop

* chlog

* Update CHANGELOG.md

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
This commit is contained in:
Nicki Skafte 2021-02-24 12:16:33 +01:00 committed by GitHub
parent a731269056
commit 1b498d1f14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 169 additions and 27 deletions

View File

@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
## [1.2.1] - 2021-02-23
### Fixed

View File

@ -66,13 +66,21 @@ class OptimizerConnector:
continue
# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
if lr_scheduler['reduce_on_plateau']:
lr_scheduler['scheduler'].step(monitor_val)
else:
lr_scheduler['scheduler'].step()
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
self.trainer.batch_idx,
interval,
scheduler_idx,
old_lr,
new_lr,
monitor_key=monitor_key,
monitor_val=monitor_val
)

View File

@ -478,6 +478,7 @@ class TrainLoop:
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
should_check_val = False
val_loop_called = False
for batch_idx, (batch, is_last_batch) in train_dataloader:
@ -513,6 +514,7 @@ class TrainLoop:
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
if should_check_val:
self.trainer.run_evaluation()
val_loop_called = True
# reset stage to train
self.trainer._running_stage = RunningStage.TRAINING
@ -558,21 +560,23 @@ class TrainLoop:
)
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval
# update epoch level lr_schedulers if no val loop outside train loop is triggered
if (val_loop_called and not should_check_val) or should_train_only:
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
if should_train_only:
self.check_checkpoint_callback(True)
self.check_early_stopping_callback(True)
if should_check_val:
self.trainer.run_evaluation(on_epoch=True)
# reset stage to train
self.trainer._running_stage = RunningStage.TRAINING
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval
if should_train_only:
# update epoch level lr_schedulers
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
self.check_checkpoint_callback(True)
self.check_early_stopping_callback(True)
# increment the global step once
# progress global step according to grads progress
self.increment_accumulated_grad_global_step()
@ -818,7 +822,7 @@ class TrainLoop:
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
can_check_val = self.trainer.enable_validation and is_val_check_epoch
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
or is_last_batch_for_infinite_dataset

View File

@ -121,13 +121,16 @@ class InternalDebugger(object):
self.saved_train_losses.append(loss_dict)
@enabled_only
def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None):
def track_lr_schedulers_update(
self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
):
loss_dict = {
'batch_idx': batch_idx,
'interval': interval,
'scheduler_idx': scheduler_idx,
'epoch': self.trainer.current_epoch,
'monitor_key': monitor_key,
'monitor_val': monitor_val,
'old_lr': old_lr,
'new_lr': new_lr
}

View File

@ -26,6 +26,7 @@ import pytest
import torch
import yaml
from omegaconf import Container, OmegaConf
from torch import optim
import pytorch_lightning as pl
import tests.helpers.utils as tutils
@ -47,8 +48,8 @@ class LogInTwoMethods(BoringModel):
def validation_epoch_end(self, outputs):
outs = torch.stack([x['x'] for x in outputs]).mean()
self.log('epoch', self.current_epoch, on_epoch=True)
self.log('val_acc', outs, on_epoch=True)
self.log('epoch', self.current_epoch)
self.log('val_acc', outs)
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@ -57,7 +58,8 @@ class LogInTwoMethods(BoringModel):
[('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
("base", None, 'train_log_epoch')],
)
def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
"""
Test that when a model checkpoint is saved, it saves with
the correct score appended to ckpt_path and checkpoint data
@ -65,6 +67,7 @@ def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step,
max_epochs = 3
limit_train_batches = 5
limit_val_batches = 7
lr = 1e-1
class CustomBoringModel(BoringModel):
@ -74,21 +77,28 @@ def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step,
self.val_logs = torch.randn(max_epochs, limit_val_batches)
def training_step(self, batch, batch_idx):
out = super().training_step(batch, batch_idx)
log_value = self.train_log_epochs[self.current_epoch, batch_idx]
self.log('train_log', log_value, on_epoch=True)
return out
return super().training_step(batch, batch_idx)
def validation_step(self, batch, batch_idx):
out = super().validation_step(batch, batch_idx)
log_value = self.val_logs[self.current_epoch, batch_idx]
self.log('val_log', log_value)
self.log('epoch', self.current_epoch, on_epoch=True)
return out
return super().validation_step(batch, batch_idx)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
optimizer = optim.SGD(self.parameters(), lr=lr)
if reduce_lr_on_plateau:
lr_scheduler = {
'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': monitor,
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
@ -109,11 +119,15 @@ def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step,
max_epochs=max_epochs,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
results = trainer.fit(model)
assert results
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
assert len(ckpt_files) == len(scores) == max_epochs
assert len(lr_scheduler_debug) == max_epochs
for epoch in range(max_epochs):
score = scores[epoch]
@ -130,9 +144,118 @@ def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step,
assert mc_specific_data['monitor'] == monitor
assert mc_specific_data['current_score'] == score
lr_scheduler_specific_data = chk['lr_schedulers'][0]
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1))
if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize(
"val_check_interval,reduce_lr_on_plateau",
[
(0.25, True),
(0.25, False),
(0.33, False),
],
)
def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
"""
Test that when a model checkpoint is saved, it saves with the correct
score appended to ckpt_path and checkpoint data with val_check_interval
"""
max_epochs = 3
limit_train_batches = 12
limit_val_batches = 7
lr = 1e-1
monitor = 'val_log'
per_epoch_steps = int(limit_train_batches * val_check_interval)
per_epoch_call_count = limit_train_batches // per_epoch_steps
class CustomBoringModel(BoringModel):
def __init__(self):
super().__init__()
self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
self.val_loop_count = 0
def validation_step(self, batch, batch_idx):
log_value = self.val_logs[self.val_loop_count, batch_idx]
self.log('val_log', log_value)
self.log('epoch', self.current_epoch, on_epoch=True)
return super().validation_step(batch, batch_idx)
def validation_epoch_end(self, outputs):
self.val_loop_count += 1
super().validation_epoch_end(outputs)
def configure_optimizers(self):
optimizer = optim.SGD(self.parameters(), lr=lr)
if reduce_lr_on_plateau:
lr_scheduler = {
'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': monitor,
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
model = CustomBoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint],
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
max_epochs=max_epochs,
val_check_interval=val_check_interval,
progress_bar_refresh_rate=0,
num_sanity_val_steps=0,
)
results = trainer.fit(model)
assert results
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
assert len(lr_scheduler_debug) == max_epochs
for epoch in range(max_epochs):
for ix in range(per_epoch_call_count):
global_ix = ix + per_epoch_call_count * epoch
score = scores[global_ix]
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
assert math.isclose(score, expected_score, rel_tol=1e-4)
chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
assert chk['epoch'] == epoch + 1
assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
mc_specific_data = chk['callbacks'][type(checkpoint)]
assert mc_specific_data['dirpath'] == checkpoint.dirpath
assert mc_specific_data['monitor'] == monitor
assert mc_specific_data['current_score'] == score
if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
did_update = 1 if ix + 1 == per_epoch_call_count else 0
assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])

View File

@ -34,6 +34,7 @@ def test_optimizer_with_scheduling(tmpdir):
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2,
val_check_interval=0.5
)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
@ -164,7 +165,7 @@ def test_reducelronplateau_scheduling(tmpdir):
model.configure_optimizers = lambda: {
'optimizer': optimizer,
'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': 'early_stop_on',
'monitor': 'val_acc',
}
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)
@ -172,7 +173,7 @@ def test_reducelronplateau_scheduling(tmpdir):
lr_scheduler = trainer.lr_schedulers[0]
assert lr_scheduler == dict(
scheduler=lr_scheduler['scheduler'],
monitor='early_stop_on',
monitor='val_acc',
interval='epoch',
frequency=1,
reduce_on_plateau=True,