[Bugfix] Fixed epoch level schedulers not being called when val_check_interval < 1.0 (#6075)
* fix bug
* fix tests
* changelog
* fix pep8
* fix tests
* fix and add some tests
* add test for rlop
* chlog
* Update CHANGELOG.md

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Commit 1b498d1f14 (parent a731269056)
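For orientation before the diffs, here is a minimal sketch of the setup this fix targets (model, data and hyperparameters are invented for illustration; only the `val_check_interval` plus `StepLR` combination comes from the PR): an epoch-interval scheduler together with mid-epoch validation, where the old epoch-end check never stepped the scheduler.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log('val_loss', nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1)
        # scheduler interval defaults to 'epoch': StepLR should step exactly once per epoch
        return [optimizer], [optim.lr_scheduler.StepLR(optimizer, step_size=1)]


dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 1))
train, val = DataLoader(dataset, batch_size=8), DataLoader(dataset, batch_size=8)
# validation runs four times per epoch; before this fix the StepLR above was never stepped
trainer = pl.Trainer(max_epochs=3, val_check_interval=0.25)
trainer.fit(TinyModel(), train, val)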
@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))

+- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))

 ## [1.2.1] - 2021-02-23

 ### Fixed
@@ -66,13 +66,21 @@ class OptimizerConnector:
                         continue

                 # update LR
                 old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

                 if lr_scheduler['reduce_on_plateau']:
                     lr_scheduler['scheduler'].step(monitor_val)
                 else:
                     lr_scheduler['scheduler'].step()

                 new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

                 if self.trainer.dev_debugger.enabled:
                     self.trainer.dev_debugger.track_lr_schedulers_update(
-                        self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
+                        self.trainer.batch_idx,
+                        interval,
+                        scheduler_idx,
+                        old_lr,
+                        new_lr,
+                        monitor_key=monitor_key,
+                        monitor_val=monitor_val
                     )
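For context, the `reduce_on_plateau` branch above is selected by the scheduler dict built from `configure_optimizers`; a minimal sketch of such a configuration (to live on a LightningModule; the `val_loss` metric name is assumed, not taken from this diff):

from torch import optim

def configure_optimizers(self):
    # Sketch only: the 'monitor' key becomes monitor_key above, and its most recently
    # logged value is the monitor_val handed to scheduler.step() for ReduceLROnPlateau.
    optimizer = optim.SGD(self.parameters(), lr=0.1)
    return {
        'optimizer': optimizer,
        'lr_scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'monitor': 'val_loss',  # assumed name, logged elsewhere via self.log('val_loss', ...)
    }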
@@ -478,6 +478,7 @@ class TrainLoop:
         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
         should_check_val = False
+        val_loop_called = False

         for batch_idx, (batch, is_last_batch) in train_dataloader:
@@ -513,6 +514,7 @@ class TrainLoop:
             should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
             if should_check_val:
                 self.trainer.run_evaluation()
+                val_loop_called = True

                 # reset stage to train
                 self.trainer._running_stage = RunningStage.TRAINING
@@ -558,21 +560,23 @@ class TrainLoop:
         )

         should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
+        should_train_only = self.trainer.disable_validation or should_skip_eval
+
+        # update epoch level lr_schedulers if no val loop outside train loop is triggered
+        if (val_loop_called and not should_check_val) or should_train_only:
+            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+
+        if should_train_only:
+            self.check_checkpoint_callback(True)
+            self.check_early_stopping_callback(True)
+
         if should_check_val:
             self.trainer.run_evaluation(on_epoch=True)

             # reset stage to train
             self.trainer._running_stage = RunningStage.TRAINING

-        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
-        should_train_only = self.trainer.disable_validation or should_skip_eval
-
-        if should_train_only:
-            # update epoch level lr_schedulers
-            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
-            self.check_checkpoint_callback(True)
-            self.check_early_stopping_callback(True)
-
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()
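Restated outside the trainer, purely as an illustration of the new control flow above (this helper does not exist in the codebase):

def should_update_epoch_schedulers_in_train_loop(val_loop_called: bool,
                                                 should_check_val: bool,
                                                 should_train_only: bool) -> bool:
    # Per the comment in the diff: step epoch-level schedulers inside the train loop only
    # when no val loop outside the train loop is about to be triggered for this epoch.
    return (val_loop_called and not should_check_val) or should_train_only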
@@ -818,7 +822,7 @@ class TrainLoop:
         is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         can_check_val = self.trainer.enable_validation and is_val_check_epoch
         is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
+        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0

         should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
                             or is_last_batch_for_infinite_dataset
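A small worked example of why the old check broke mid-epoch validation setups (batch counts assumed, matching the new test added below):

# with 12 training batches and val_check_interval=0.25, validation runs every 3 batches
num_training_batches = 12
val_check_batch = int(num_training_batches * 0.25)  # 3

# old check: a constant, False whenever val_check_interval < 1.0, so the epoch-end
# validation path (and with it the epoch-level scheduler step) never ran
old_epoch_end_val_check = val_check_batch == num_training_batches
assert old_epoch_end_val_check is False

# new check: true exactly once, on the last batch of the epoch
new_epoch_end_val_check = [(batch_idx + 1) % num_training_batches == 0
                           for batch_idx in range(num_training_batches)]
assert new_epoch_end_val_check.count(True) == 1 and new_epoch_end_val_check[-1]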
@@ -121,13 +121,16 @@ class InternalDebugger(object):
         self.saved_train_losses.append(loss_dict)

     @enabled_only
-    def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None):
+    def track_lr_schedulers_update(
+        self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
+    ):
         loss_dict = {
             'batch_idx': batch_idx,
             'interval': interval,
             'scheduler_idx': scheduler_idx,
             'epoch': self.trainer.current_epoch,
             'monitor_key': monitor_key,
+            'monitor_val': monitor_val,
             'old_lr': old_lr,
             'new_lr': new_lr
         }
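With `PL_DEV_DEBUG=1`, these tracked dicts are what the tests below read back; a hedged sketch of that access pattern (hypothetical helper, mirroring the later assertions):

def check_lr_scheduler_tracking(trainer):
    # hypothetical helper: every tracked update now carries monitor_key and monitor_val,
    # which stay None unless the scheduler is a ReduceLROnPlateau with a monitor
    for entry in trainer.dev_debugger.saved_lr_scheduler_updates:
        assert {'monitor_key', 'monitor_val', 'old_lr', 'new_lr'} <= set(entry)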
@@ -26,6 +26,7 @@ import pytest
 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch import optim

 import pytorch_lightning as pl
 import tests.helpers.utils as tutils
@@ -47,8 +48,8 @@ class LogInTwoMethods(BoringModel):

     def validation_epoch_end(self, outputs):
         outs = torch.stack([x['x'] for x in outputs]).mean()
-        self.log('epoch', self.current_epoch, on_epoch=True)
-        self.log('val_acc', outs, on_epoch=True)
+        self.log('epoch', self.current_epoch)
+        self.log('val_acc', outs)


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -57,7 +58,8 @@ class LogInTwoMethods(BoringModel):
     [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
      ("base", None, 'train_log_epoch')],
 )
-def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
+@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
+def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
     """
     Test that when a model checkpoint is saved, it saves with
     the correct score appended to ckpt_path and checkpoint data
@@ -65,6 +67,7 @@ def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step,
     max_epochs = 3
     limit_train_batches = 5
     limit_val_batches = 7
+    lr = 1e-1

     class CustomBoringModel(BoringModel):
@@ -74,21 +77,28 @@ def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step,
             self.val_logs = torch.randn(max_epochs, limit_val_batches)

         def training_step(self, batch, batch_idx):
-            out = super().training_step(batch, batch_idx)
             log_value = self.train_log_epochs[self.current_epoch, batch_idx]
             self.log('train_log', log_value, on_epoch=True)
-            return out
+            return super().training_step(batch, batch_idx)

         def validation_step(self, batch, batch_idx):
-            out = super().validation_step(batch, batch_idx)
             log_value = self.val_logs[self.current_epoch, batch_idx]
             self.log('val_log', log_value)
-            return out
+            self.log('epoch', self.current_epoch, on_epoch=True)
+            return super().validation_step(batch, batch_idx)

         def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
-            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

             return [optimizer], [lr_scheduler]

     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
@@ -109,11 +119,15 @@ def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step,
         max_epochs=max_epochs,
         progress_bar_refresh_rate=0,
     )
-    trainer.fit(model)
+    results = trainer.fit(model)
+    assert results
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
     assert len(ckpt_files) == len(scores) == max_epochs
+    assert len(lr_scheduler_debug) == max_epochs

     for epoch in range(max_epochs):
         score = scores[epoch]
@@ -130,9 +144,118 @@ def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step,
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score

-        lr_scheduler_specific_data = chk['lr_schedulers'][0]
-        assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-        assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1))
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
+
+        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
+
+
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.parametrize(
+    "val_check_interval,reduce_lr_on_plateau",
+    [
+        (0.25, True),
+        (0.25, False),
+        (0.33, False),
+    ],
+)
+def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
+    """
+    Test that when a model checkpoint is saved, it saves with the correct
+    score appended to ckpt_path and checkpoint data with val_check_interval
+    """
+    max_epochs = 3
+    limit_train_batches = 12
+    limit_val_batches = 7
+    lr = 1e-1
+    monitor = 'val_log'
+    per_epoch_steps = int(limit_train_batches * val_check_interval)
+    per_epoch_call_count = limit_train_batches // per_epoch_steps
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
+            self.val_loop_count = 0
+
+        def validation_step(self, batch, batch_idx):
+            log_value = self.val_logs[self.val_loop_count, batch_idx]
+            self.log('val_log', log_value)
+            self.log('epoch', self.current_epoch, on_epoch=True)
+            return super().validation_step(batch, batch_idx)
+
+        def validation_epoch_end(self, outputs):
+            self.val_loop_count += 1
+            super().validation_epoch_end(outputs)
+
+        def configure_optimizers(self):
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+            return [optimizer], [lr_scheduler]

+    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
+
+    model = CustomBoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        max_epochs=max_epochs,
+        val_check_interval=val_check_interval,
+        progress_bar_refresh_rate=0,
+        num_sanity_val_steps=0,
+    )
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
+    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
+    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
+    assert len(lr_scheduler_debug) == max_epochs
+
+    for epoch in range(max_epochs):
+        for ix in range(per_epoch_call_count):
+            global_ix = ix + per_epoch_call_count * epoch
+            score = scores[global_ix]
+            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
+            assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+            assert chk['epoch'] == epoch + 1
+            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
+
+            mc_specific_data = chk['callbacks'][type(checkpoint)]
+            assert mc_specific_data['dirpath'] == checkpoint.dirpath
+            assert mc_specific_data['monitor'] == monitor
+            assert mc_specific_data['current_score'] == score
+
+            if not reduce_lr_on_plateau:
+                lr_scheduler_specific_data = chk['lr_schedulers'][0]
+                did_update = 1 if ix + 1 == per_epoch_call_count else 0
+                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+
+            assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+            assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
+
+
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
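A quick check of the new test's bookkeeping for its val_check_interval=0.25 cases (values taken from the diff above):

max_epochs, limit_train_batches, val_check_interval = 3, 12, 0.25

per_epoch_steps = int(limit_train_batches * val_check_interval)  # 3 training batches between val loops
per_epoch_call_count = limit_train_batches // per_epoch_steps    # 4 val loops (and checkpoints) per epoch

assert per_epoch_steps == 3 and per_epoch_call_count == 4
# 12 checkpoints and logged scores in total, but only max_epochs epoch-level scheduler
# updates -- exactly the behaviour this PR restores
assert per_epoch_call_count * max_epochs == 12
assert max_epochs == 3  # expected len(lr_scheduler_debug)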
@@ -34,6 +34,7 @@ def test_optimizer_with_scheduling(tmpdir):
         max_epochs=1,
         limit_val_batches=0.1,
         limit_train_batches=0.2,
+        val_check_interval=0.5
     )
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
@@ -164,7 +165,7 @@ def test_reducelronplateau_scheduling(tmpdir):
     model.configure_optimizers = lambda: {
         'optimizer': optimizer,
         'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
-        'monitor': 'early_stop_on',
+        'monitor': 'val_acc',
     }
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model)
@@ -172,7 +173,7 @@ def test_reducelronplateau_scheduling(tmpdir):
     lr_scheduler = trainer.lr_schedulers[0]
     assert lr_scheduler == dict(
         scheduler=lr_scheduler['scheduler'],
-        monitor='early_stop_on',
+        monitor='val_acc',
         interval='epoch',
         frequency=1,
         reduce_on_plateau=True,