Always run validation inside the training loop epoch (#7357)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Carlos Mocholí 2021-05-26 14:26:48 +02:00 committed by GitHub
parent 037a71b156
commit 311d9fe67e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 64 additions and 86 deletions

View File

@ -62,6 +62,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
- Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))
- Refactored Loops
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))

View File

@ -928,7 +928,7 @@ class Trainer(
self.state.stage = None
raise
def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
def _run_evaluation(self) -> _EVALUATE_OUTPUT:
if not (self.evaluating or self.sanity_checking):
rank_zero_warn(
f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}."
@ -1010,17 +1010,6 @@ class Trainer(
# hook
self.evaluation_loop.on_evaluation_epoch_end()
# update epoch-level lr_schedulers
if on_epoch:
self.optimizer_connector.update_learning_rates(
interval='epoch',
opt_indices=[
opt_idx for opt_idx, _ in self.train_loop.get_active_optimizers(
batch_idx=(self.train_loop.total_batch_idx - 1)
) # Select the optimizers which were used in the last batch of the epoch
],
)
# log epoch metrics
eval_loop_results = self.logger_connector.get_evaluate_epoch_results()

View File

@ -479,7 +479,6 @@ class TrainLoop:
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
batch_idx = None
is_last_batch = None
for batch_idx, (batch, is_last_batch) in train_dataloader:
self.batch_idx = batch_idx
@ -529,16 +528,13 @@ class TrainLoop:
self.total_batch_idx += 1
max_steps_reached = (
self.max_steps is not None and self.max_steps <= self.global_step + 1
and self._accumulated_batches_reached()
)
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
break
# progress global step according to grads progress
self.increment_accumulated_grad_global_step()
max_steps_reached = (self.max_steps is not None and self.max_steps <= self.global_step)
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
break
if batch_idx is None:
# dataloader/iterator did not produce a batch
return
@ -546,27 +542,24 @@ class TrainLoop:
# handle epoch_output on epoch end
self.on_train_epoch_end(epoch_output)
# the global step is manually decreased here due to backwards compatibility with existing loggers
# as they expect that the same step is used when logging epoch end metrics even when the batch loop has
# finished. this means the attribute does not exactly track the number of optimizer steps applied.
# TODO(@carmocca): deprecate and rename so users don't get confused
self.global_step -= 1
# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
self.global_step += 1
should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval
self.update_lr_schedulers('epoch')
# update epoch level lr_schedulers if no val loop outside train loop is triggered
if not should_check_val or should_train_only:
self.update_lr_schedulers('epoch')
if should_train_only:
did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.should_skip_evaluation(
self.trainer.num_val_batches
)
if did_train_only:
self.global_step -= 1
self.check_checkpoint_callback(True)
if should_check_val:
self.trainer.validating = True
self.trainer._run_evaluation(on_epoch=True)
self.trainer.training = True
if batch_output.signal != -1:
self.increment_accumulated_grad_global_step()
self.global_step += 1
def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
# inform logger the batch loop has finished
@ -882,7 +875,7 @@ class TrainLoop:
is_final_batch = self._num_training_batches_reached()
return not (accumulation_done or is_final_batch)
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
""" Decide if we should run validation. """
if not self.trainer.enable_validation:
return False
@ -893,26 +886,19 @@ class TrainLoop:
# val_check_batch is inf for iterable datasets with no length defined
is_infinite_dataset = self.trainer.val_check_batch == float('inf')
if on_epoch and is_last_batch and is_infinite_dataset:
if is_last_batch and is_infinite_dataset:
return True
if self.trainer.should_stop:
return True
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = False
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
is_val_check_batch = is_last_batch
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float('inf'):
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
# Note: num_training_batches is also inf for iterable datasets with no length defined
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
if on_epoch:
return is_val_check_batch and epoch_end_val_check
else:
return is_val_check_batch and not epoch_end_val_check
return is_val_check_batch
def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens):
# enable not needing to add opt_idx to training_step

View File

@ -83,8 +83,6 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
call.on_batch_end(trainer, model),
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
@ -94,6 +92,8 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_train_end(trainer, model),
call.on_fit_end(trainer, model),
call.teardown(trainer, model, 'fit'),

View File

@ -169,7 +169,7 @@ def test_early_stopping_patience_train(
model.validation_step = None
early_stop_callback = EarlyStopping(
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=validation_step_none
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
)
trainer = Trainer(
default_root_dir=tmpdir,

View File

@ -60,9 +60,8 @@ class LogInTwoMethods(BoringModel):
"validation_step_none,val_dataloaders_none,monitor",
[
(False, False, 'val_log'),
(False, False, 'train_log_epoch'),
(True, False, 'train_log_epoch'),
(False, True, 'train_log_epoch'),
(False, True, 'val_log'),
],
)
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
@ -76,7 +75,7 @@ def test_model_checkpoint_score_and_ckpt(
max_epochs = 3
limit_train_batches = 5
limit_val_batches = 7
lr = 1e-1
lr, gamma = 1e-1, 2
class CustomBoringModel(BoringModel):
@ -106,7 +105,7 @@ def test_model_checkpoint_score_and_ckpt(
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
return [optimizer], [lr_scheduler]
@ -153,9 +152,12 @@ def test_model_checkpoint_score_and_ckpt(
assert mc_specific_data['current_score'] == score
if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
actual_step_count = chk['lr_schedulers'][0]['_step_count']
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
# if validation_step_none, the checkpoint gets saved after the learning rate update
# so we need to increase the count by one
assert actual_step_count == epoch + 1 + validation_step_none
assert actual_lr == lr * gamma**(epoch + validation_step_none)
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@ -180,23 +182,21 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
max_epochs = 3
limit_train_batches = 12
limit_val_batches = 7
lr = 1e-1
lr, gamma = 1e-1, 2
monitor = 'val_log'
per_epoch_steps = int(limit_train_batches * val_check_interval)
per_epoch_call_count = limit_train_batches // per_epoch_steps
left_over_steps = limit_train_batches % per_epoch_steps
per_val_train_batches = int(limit_train_batches * val_check_interval)
per_epoch_val_checks, leftover_train_batches = divmod(limit_train_batches, per_val_train_batches)
class CustomBoringModel(BoringModel):
def __init__(self):
super().__init__()
self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches)
self.val_loop_count = 0
def validation_step(self, batch, batch_idx):
log_value = self.val_logs[self.val_loop_count, batch_idx]
self.log('val_log', log_value)
self.log('epoch', self.current_epoch, on_epoch=True)
return super().validation_step(batch, batch_idx)
def validation_epoch_end(self, outputs):
@ -213,7 +213,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
return [optimizer], [lr_scheduler]
@ -241,26 +241,27 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
# on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
# end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
additional_ckpt, additional_ckpt_path = 0, None
additional_ckpt, additional_ckpt_path = False, None
if not epoch_aligned:
additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
additional_ckpt = 1
additional_ckpt = True
additional_ckpt = 1 if not epoch_aligned else 0
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt
assert len(lr_scheduler_debug) == max_epochs
def _make_assertions(epoch, ix, add=''):
global_ix = ix + per_epoch_call_count * epoch
def _make_assertions(epoch, ix, version=''):
global_ix = ix + per_epoch_val_checks * epoch
duplicated = bool(version)
score = scores[global_ix]
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{add}.ckpt'
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
assert math.isclose(score, expected_score, rel_tol=1e-4)
chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
assert chk['epoch'] == epoch + 1
epoch_num = epoch + (1 if add else 0)
expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
epoch_num = epoch + duplicated
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
assert chk['global_step'] == expected_global_step
mc_specific_data = chk['callbacks'][type(checkpoint)]
@ -269,15 +270,15 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
assert mc_specific_data['current_score'] == score
if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
did_update = 1 if (ix + 1 == per_epoch_call_count) and (epoch_aligned or add) else 0
assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
actual_step_count = chk['lr_schedulers'][0]['_step_count']
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
assert actual_step_count == epoch + 1 + duplicated
assert actual_lr == lr * gamma**(epoch + duplicated)
return score
for epoch in range(max_epochs):
for i in range(per_epoch_call_count):
for i in range(per_epoch_val_checks):
score = _make_assertions(epoch, i)
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
@ -285,9 +286,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
# check the ckpt file saved on_train_end
if additional_ckpt_path:
epoch = max_epochs - 1
i = per_epoch_call_count - 1
_make_assertions(epoch, i, add='-v1')
_make_assertions(max_epochs - 1, per_epoch_val_checks - 1, version='-v1')
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])

View File

@ -469,9 +469,6 @@ def test_trainer_model_hook_system_fit(tmpdir):
'on_epoch_start',
'on_train_epoch_start',
*(model.train_batch * train_batches),
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_validation_model_eval',
'on_validation_start',
'on_epoch_start',
@ -483,6 +480,9 @@ def test_trainer_model_hook_system_fit(tmpdir):
'on_save_checkpoint',
'on_validation_end',
'on_validation_model_train',
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_train_end',
'on_fit_end',
'teardown_fit',

View File

@ -141,4 +141,4 @@ def test_should_stop_mid_epoch(tmpdir):
assert trainer.current_epoch == 0
assert trainer.global_step == 5
assert model.validation_called_at == (0, 4) # TODO(@carmocca): should be 5 - will be fixed in next PR
assert model.validation_called_at == (0, 4)