Always run validation inside the training loop epoch (#7357)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
037a71b156
commit
311d9fe67e
|
@ -62,6 +62,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
|
||||
|
||||
|
||||
- Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))
|
||||
|
||||
|
||||
- Refactored Loops
|
||||
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
|
||||
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
|
||||
|
|
|
@ -928,7 +928,7 @@ class Trainer(
|
|||
self.state.stage = None
|
||||
raise
|
||||
|
||||
def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
|
||||
def _run_evaluation(self) -> _EVALUATE_OUTPUT:
|
||||
if not (self.evaluating or self.sanity_checking):
|
||||
rank_zero_warn(
|
||||
f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}."
|
||||
|
@ -1010,17 +1010,6 @@ class Trainer(
|
|||
# hook
|
||||
self.evaluation_loop.on_evaluation_epoch_end()
|
||||
|
||||
# update epoch-level lr_schedulers
|
||||
if on_epoch:
|
||||
self.optimizer_connector.update_learning_rates(
|
||||
interval='epoch',
|
||||
opt_indices=[
|
||||
opt_idx for opt_idx, _ in self.train_loop.get_active_optimizers(
|
||||
batch_idx=(self.train_loop.total_batch_idx - 1)
|
||||
) # Select the optimizers which were used in the last batch of the epoch
|
||||
],
|
||||
)
|
||||
|
||||
# log epoch metrics
|
||||
eval_loop_results = self.logger_connector.get_evaluate_epoch_results()
|
||||
|
||||
|
|
|
@ -479,7 +479,6 @@ class TrainLoop:
|
|||
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
|
||||
dataloader_idx = 0
|
||||
batch_idx = None
|
||||
is_last_batch = None
|
||||
|
||||
for batch_idx, (batch, is_last_batch) in train_dataloader:
|
||||
self.batch_idx = batch_idx
|
||||
|
@ -529,16 +528,13 @@ class TrainLoop:
|
|||
|
||||
self.total_batch_idx += 1
|
||||
|
||||
max_steps_reached = (
|
||||
self.max_steps is not None and self.max_steps <= self.global_step + 1
|
||||
and self._accumulated_batches_reached()
|
||||
)
|
||||
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
|
||||
break
|
||||
|
||||
# progress global step according to grads progress
|
||||
self.increment_accumulated_grad_global_step()
|
||||
|
||||
max_steps_reached = (self.max_steps is not None and self.max_steps <= self.global_step)
|
||||
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
|
||||
break
|
||||
|
||||
if batch_idx is None:
|
||||
# dataloader/iterator did not produce a batch
|
||||
return
|
||||
|
@ -546,27 +542,24 @@ class TrainLoop:
|
|||
# handle epoch_output on epoch end
|
||||
self.on_train_epoch_end(epoch_output)
|
||||
|
||||
# the global step is manually decreased here due to backwards compatibility with existing loggers
|
||||
# as they expect that the same step is used when logging epoch end metrics even when the batch loop has
|
||||
# finished. this means the attribute does not exactly track the number of optimizer steps applied.
|
||||
# TODO(@carmocca): deprecate and rename so users don't get confused
|
||||
self.global_step -= 1
|
||||
# log epoch metrics
|
||||
self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
|
||||
self.global_step += 1
|
||||
|
||||
should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
|
||||
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
|
||||
should_train_only = self.trainer.disable_validation or should_skip_eval
|
||||
self.update_lr_schedulers('epoch')
|
||||
|
||||
# update epoch level lr_schedulers if no val loop outside train loop is triggered
|
||||
if not should_check_val or should_train_only:
|
||||
self.update_lr_schedulers('epoch')
|
||||
|
||||
if should_train_only:
|
||||
did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.should_skip_evaluation(
|
||||
self.trainer.num_val_batches
|
||||
)
|
||||
if did_train_only:
|
||||
self.global_step -= 1
|
||||
self.check_checkpoint_callback(True)
|
||||
|
||||
if should_check_val:
|
||||
self.trainer.validating = True
|
||||
self.trainer._run_evaluation(on_epoch=True)
|
||||
self.trainer.training = True
|
||||
|
||||
if batch_output.signal != -1:
|
||||
self.increment_accumulated_grad_global_step()
|
||||
self.global_step += 1
|
||||
|
||||
def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
|
||||
# inform logger the batch loop has finished
|
||||
|
@ -882,7 +875,7 @@ class TrainLoop:
|
|||
is_final_batch = self._num_training_batches_reached()
|
||||
return not (accumulation_done or is_final_batch)
|
||||
|
||||
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
|
||||
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
|
||||
""" Decide if we should run validation. """
|
||||
if not self.trainer.enable_validation:
|
||||
return False
|
||||
|
@ -893,26 +886,19 @@ class TrainLoop:
|
|||
|
||||
# val_check_batch is inf for iterable datasets with no length defined
|
||||
is_infinite_dataset = self.trainer.val_check_batch == float('inf')
|
||||
if on_epoch and is_last_batch and is_infinite_dataset:
|
||||
if is_last_batch and is_infinite_dataset:
|
||||
return True
|
||||
|
||||
if self.trainer.should_stop:
|
||||
return True
|
||||
|
||||
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
|
||||
is_val_check_batch = False
|
||||
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
|
||||
is_val_check_batch = is_last_batch
|
||||
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
|
||||
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
|
||||
elif self.trainer.val_check_batch != float('inf'):
|
||||
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
|
||||
|
||||
# Note: num_training_batches is also inf for iterable datasets with no length defined
|
||||
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
|
||||
|
||||
if on_epoch:
|
||||
return is_val_check_batch and epoch_end_val_check
|
||||
else:
|
||||
return is_val_check_batch and not epoch_end_val_check
|
||||
return is_val_check_batch
|
||||
|
||||
def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens):
|
||||
# enable not needing to add opt_idx to training_step
|
||||
|
|
|
@ -83,8 +83,6 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
|||
call.on_after_backward(trainer, model),
|
||||
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
|
||||
call.on_batch_end(trainer, model),
|
||||
call.on_train_epoch_end(trainer, model, ANY),
|
||||
call.on_epoch_end(trainer, model),
|
||||
call.on_validation_start(trainer, model),
|
||||
call.on_epoch_start(trainer, model),
|
||||
call.on_validation_epoch_start(trainer, model),
|
||||
|
@ -94,6 +92,8 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
|||
call.on_epoch_end(trainer, model),
|
||||
call.on_validation_end(trainer, model),
|
||||
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
|
||||
call.on_train_epoch_end(trainer, model, ANY),
|
||||
call.on_epoch_end(trainer, model),
|
||||
call.on_train_end(trainer, model),
|
||||
call.on_fit_end(trainer, model),
|
||||
call.teardown(trainer, model, 'fit'),
|
||||
|
|
|
@ -169,7 +169,7 @@ def test_early_stopping_patience_train(
|
|||
model.validation_step = None
|
||||
|
||||
early_stop_callback = EarlyStopping(
|
||||
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=validation_step_none
|
||||
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
|
||||
)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
|
|
|
@ -60,9 +60,8 @@ class LogInTwoMethods(BoringModel):
|
|||
"validation_step_none,val_dataloaders_none,monitor",
|
||||
[
|
||||
(False, False, 'val_log'),
|
||||
(False, False, 'train_log_epoch'),
|
||||
(True, False, 'train_log_epoch'),
|
||||
(False, True, 'train_log_epoch'),
|
||||
(False, True, 'val_log'),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
|
||||
|
@ -76,7 +75,7 @@ def test_model_checkpoint_score_and_ckpt(
|
|||
max_epochs = 3
|
||||
limit_train_batches = 5
|
||||
limit_val_batches = 7
|
||||
lr = 1e-1
|
||||
lr, gamma = 1e-1, 2
|
||||
|
||||
class CustomBoringModel(BoringModel):
|
||||
|
||||
|
@ -106,7 +105,7 @@ def test_model_checkpoint_score_and_ckpt(
|
|||
'strict': True,
|
||||
}
|
||||
else:
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
|
||||
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
|
@ -153,9 +152,12 @@ def test_model_checkpoint_score_and_ckpt(
|
|||
assert mc_specific_data['current_score'] == score
|
||||
|
||||
if not reduce_lr_on_plateau:
|
||||
lr_scheduler_specific_data = chk['lr_schedulers'][0]
|
||||
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
|
||||
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
|
||||
actual_step_count = chk['lr_schedulers'][0]['_step_count']
|
||||
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
|
||||
# if validation_step_none, the checkpoint gets saved after the learning rate update
|
||||
# so we need to increase the count by one
|
||||
assert actual_step_count == epoch + 1 + validation_step_none
|
||||
assert actual_lr == lr * gamma**(epoch + validation_step_none)
|
||||
|
||||
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
|
||||
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
|
||||
|
@ -180,23 +182,21 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
|||
max_epochs = 3
|
||||
limit_train_batches = 12
|
||||
limit_val_batches = 7
|
||||
lr = 1e-1
|
||||
lr, gamma = 1e-1, 2
|
||||
monitor = 'val_log'
|
||||
per_epoch_steps = int(limit_train_batches * val_check_interval)
|
||||
per_epoch_call_count = limit_train_batches // per_epoch_steps
|
||||
left_over_steps = limit_train_batches % per_epoch_steps
|
||||
per_val_train_batches = int(limit_train_batches * val_check_interval)
|
||||
per_epoch_val_checks, leftover_train_batches = divmod(limit_train_batches, per_val_train_batches)
|
||||
|
||||
class CustomBoringModel(BoringModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
|
||||
self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches)
|
||||
self.val_loop_count = 0
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_value = self.val_logs[self.val_loop_count, batch_idx]
|
||||
self.log('val_log', log_value)
|
||||
self.log('epoch', self.current_epoch, on_epoch=True)
|
||||
return super().validation_step(batch, batch_idx)
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
|
@ -213,7 +213,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
|||
'strict': True,
|
||||
}
|
||||
else:
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
|
||||
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
|
@ -241,26 +241,27 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
|||
|
||||
# on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
|
||||
# end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
|
||||
additional_ckpt, additional_ckpt_path = 0, None
|
||||
additional_ckpt, additional_ckpt_path = False, None
|
||||
if not epoch_aligned:
|
||||
additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
|
||||
additional_ckpt = 1
|
||||
additional_ckpt = True
|
||||
|
||||
additional_ckpt = 1 if not epoch_aligned else 0
|
||||
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
|
||||
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt
|
||||
assert len(lr_scheduler_debug) == max_epochs
|
||||
|
||||
def _make_assertions(epoch, ix, add=''):
|
||||
global_ix = ix + per_epoch_call_count * epoch
|
||||
def _make_assertions(epoch, ix, version=''):
|
||||
global_ix = ix + per_epoch_val_checks * epoch
|
||||
duplicated = bool(version)
|
||||
|
||||
score = scores[global_ix]
|
||||
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
|
||||
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{add}.ckpt'
|
||||
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
|
||||
assert math.isclose(score, expected_score, rel_tol=1e-4)
|
||||
|
||||
chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
|
||||
assert chk['epoch'] == epoch + 1
|
||||
epoch_num = epoch + (1 if add else 0)
|
||||
expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
|
||||
epoch_num = epoch + duplicated
|
||||
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
|
||||
assert chk['global_step'] == expected_global_step
|
||||
|
||||
mc_specific_data = chk['callbacks'][type(checkpoint)]
|
||||
|
@ -269,15 +270,15 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
|||
assert mc_specific_data['current_score'] == score
|
||||
|
||||
if not reduce_lr_on_plateau:
|
||||
lr_scheduler_specific_data = chk['lr_schedulers'][0]
|
||||
did_update = 1 if (ix + 1 == per_epoch_call_count) and (epoch_aligned or add) else 0
|
||||
assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
|
||||
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
|
||||
actual_step_count = chk['lr_schedulers'][0]['_step_count']
|
||||
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
|
||||
assert actual_step_count == epoch + 1 + duplicated
|
||||
assert actual_lr == lr * gamma**(epoch + duplicated)
|
||||
|
||||
return score
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
for i in range(per_epoch_call_count):
|
||||
for i in range(per_epoch_val_checks):
|
||||
score = _make_assertions(epoch, i)
|
||||
|
||||
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
|
||||
|
@ -285,9 +286,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
|||
|
||||
# check the ckpt file saved on_train_end
|
||||
if additional_ckpt_path:
|
||||
epoch = max_epochs - 1
|
||||
i = per_epoch_call_count - 1
|
||||
_make_assertions(epoch, i, add='-v1')
|
||||
_make_assertions(max_epochs - 1, per_epoch_val_checks - 1, version='-v1')
|
||||
|
||||
|
||||
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
|
||||
|
|
|
@ -469,9 +469,6 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
|||
'on_epoch_start',
|
||||
'on_train_epoch_start',
|
||||
*(model.train_batch * train_batches),
|
||||
'training_epoch_end',
|
||||
'on_train_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_validation_model_eval',
|
||||
'on_validation_start',
|
||||
'on_epoch_start',
|
||||
|
@ -483,6 +480,9 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
|||
'on_save_checkpoint',
|
||||
'on_validation_end',
|
||||
'on_validation_model_train',
|
||||
'training_epoch_end',
|
||||
'on_train_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_train_end',
|
||||
'on_fit_end',
|
||||
'teardown_fit',
|
||||
|
|
|
@ -141,4 +141,4 @@ def test_should_stop_mid_epoch(tmpdir):
|
|||
|
||||
assert trainer.current_epoch == 0
|
||||
assert trainer.global_step == 5
|
||||
assert model.validation_called_at == (0, 4) # TODO(@carmocca): should be 5 - will be fixed in next PR
|
||||
assert model.validation_called_at == (0, 4)
|
||||
|
|
Loading…
Reference in New Issue