From 311d9fe67e99908570b151577749bd08118961d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 26 May 2021 14:26:48 +0200
Subject: [PATCH] Always run validation inside the training loop epoch (#7357)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                                 |  4 ++
 pytorch_lightning/trainer/trainer.py         | 13 +----
 pytorch_lightning/trainer/training_loop.py   | 58 +++++++------------
 tests/callbacks/test_callbacks.py            |  4 +-
 tests/callbacks/test_early_stopping.py       |  2 +-
 tests/checkpointing/test_model_checkpoint.py | 61 ++++++++++----------
 tests/models/test_hooks.py                   |  6 +-
 tests/trainer/loops/test_training_loop.py    |  2 +-
 8 files changed, 64 insertions(+), 86 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c07b13ac93..199aa70329 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -62,6 +62,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
 
+
+- Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))
+
+
 - Refactored Loops
     * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
     * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 862abdbea4..b01f4fa36b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -928,7 +928,7 @@ class Trainer(
             self.state.stage = None
             raise
 
-    def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
+    def _run_evaluation(self) -> _EVALUATE_OUTPUT:
         if not (self.evaluating or self.sanity_checking):
             rank_zero_warn(
                 f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}."
@@ -1010,17 +1010,6 @@ class Trainer(
         # hook
         self.evaluation_loop.on_evaluation_epoch_end()
 
-        # update epoch-level lr_schedulers
-        if on_epoch:
-            self.optimizer_connector.update_learning_rates(
-                interval='epoch',
-                opt_indices=[
-                    opt_idx for opt_idx, _ in self.train_loop.get_active_optimizers(
-                        batch_idx=(self.train_loop.total_batch_idx - 1)
-                    )  # Select the optimizers which were used in the last batch of the epoch
-                ],
-            )
-
         # log epoch metrics
         eval_loop_results = self.logger_connector.get_evaluate_epoch_results()
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 6213879013..09a32c3c96 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -479,7 +479,6 @@ class TrainLoop:
         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
         batch_idx = None
-        is_last_batch = None
 
         for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.batch_idx = batch_idx
@@ -529,16 +528,13 @@ class TrainLoop:
 
             self.total_batch_idx += 1
 
-            max_steps_reached = (
-                self.max_steps is not None and self.max_steps <= self.global_step + 1
-                and self._accumulated_batches_reached()
-            )
-            if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
-                break
-
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()
 
+            max_steps_reached = (self.max_steps is not None and self.max_steps <= self.global_step)
+            if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
+                break
+
         if batch_idx is None:
             # dataloader/iterator did not produce a batch
             return
@@ -546,27 +542,24 @@ class TrainLoop:
         # handle epoch_output on epoch end
         self.on_train_epoch_end(epoch_output)
 
+        # the global step is manually decreased here due to backwards compatibility with existing loggers
+        # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
+        # finished. this means the attribute does not exactly track the number of optimizer steps applied.
+        # TODO(@carmocca): deprecate and rename so users don't get confused
+        self.global_step -= 1
         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
+        self.global_step += 1
 
-        should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
-        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
-        should_train_only = self.trainer.disable_validation or should_skip_eval
+        self.update_lr_schedulers('epoch')
 
-        # update epoch level lr_schedulers if no val loop outside train loop is triggered
-        if not should_check_val or should_train_only:
-            self.update_lr_schedulers('epoch')
-
-        if should_train_only:
+        did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.should_skip_evaluation(
+            self.trainer.num_val_batches
+        )
+        if did_train_only:
+            self.global_step -= 1
             self.check_checkpoint_callback(True)
-
-        if should_check_val:
-            self.trainer.validating = True
-            self.trainer._run_evaluation(on_epoch=True)
-            self.trainer.training = True
-
-        if batch_output.signal != -1:
-            self.increment_accumulated_grad_global_step()
+            self.global_step += 1
 
     def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
         # inform logger the batch loop has finished
@@ -882,7 +875,7 @@ class TrainLoop:
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)
 
-    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
+    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         """ Decide if we should run validation. """
         if not self.trainer.enable_validation:
             return False
@@ -893,26 +886,19 @@ class TrainLoop:
 
         # val_check_batch is inf for iterable datasets with no length defined
         is_infinite_dataset = self.trainer.val_check_batch == float('inf')
-        if on_epoch and is_last_batch and is_infinite_dataset:
+        if is_last_batch and is_infinite_dataset:
            return True
 
         if self.trainer.should_stop:
             return True
 
         # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
-        is_val_check_batch = False
-        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
+        is_val_check_batch = is_last_batch
+        if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float('inf'):
             is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
-
-        # Note: num_training_batches is also inf for iterable datasets with no length defined
-        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
-
-        if on_epoch:
-            return is_val_check_batch and epoch_end_val_check
-        else:
-            return is_val_check_batch and not epoch_end_val_check
+        return is_val_check_batch
 
     def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 9b048e022c..a22e72ce09 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -83,8 +83,6 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
         call.on_after_backward(trainer, model),
         call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
         call.on_batch_end(trainer, model),
-        call.on_train_epoch_end(trainer, model, ANY),
-        call.on_epoch_end(trainer, model),
         call.on_validation_start(trainer, model),
         call.on_epoch_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
@@ -94,6 +92,8 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
         call.on_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
         call.on_save_checkpoint(trainer, model),  # should take ANY but we are inspecting signature for BC
+        call.on_train_epoch_end(trainer, model, ANY),
+        call.on_epoch_end(trainer, model),
         call.on_train_end(trainer, model),
         call.on_fit_end(trainer, model),
         call.teardown(trainer, model, 'fit'),
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index b1242de725..7d303e6ed0 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -169,7 +169,7 @@ def test_early_stopping_patience_train(
         model.validation_step = None
 
     early_stop_callback = EarlyStopping(
-        monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=validation_step_none
+        monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
     )
     trainer = Trainer(
         default_root_dir=tmpdir,
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 3d7a35917e..2f867d4e99 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -60,9 +60,8 @@ class LogInTwoMethods(BoringModel):
 @pytest.mark.parametrize(
     "validation_step_none,val_dataloaders_none,monitor", [
         (False, False, 'val_log'),
-        (False, False, 'train_log_epoch'),
         (True, False, 'train_log_epoch'),
-        (False, True, 'train_log_epoch'),
+        (False, True, 'val_log'),
     ],
 )
 @pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
@@ -76,7 +75,7 @@ def test_model_checkpoint_score_and_ckpt(
     max_epochs = 3
     limit_train_batches = 5
     limit_val_batches = 7
-    lr = 1e-1
+    lr, gamma = 1e-1, 2
 
     class CustomBoringModel(BoringModel):
 
@@ -106,7 +105,7 @@ def test_model_checkpoint_score_and_ckpt(
                 'strict': True,
             }
         else:
-            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
 
         return [optimizer], [lr_scheduler]
 
@@ -153,9 +152,12 @@ def test_model_checkpoint_score_and_ckpt(
         assert mc_specific_data['current_score'] == score
 
         if not reduce_lr_on_plateau:
-            lr_scheduler_specific_data = chk['lr_schedulers'][0]
-            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
+            actual_step_count = chk['lr_schedulers'][0]['_step_count']
+            actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
+            # if validation_step_none, the checkpoint gets saved after the learning rate update
+            # so we need to increase the count by one
+            assert actual_step_count == epoch + 1 + validation_step_none
+            assert actual_lr == lr * gamma**(epoch + validation_step_none)
 
         assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
         assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@@ -180,23 +182,21 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
     max_epochs = 3
     limit_train_batches = 12
     limit_val_batches = 7
-    lr = 1e-1
+    lr, gamma = 1e-1, 2
     monitor = 'val_log'
-    per_epoch_steps = int(limit_train_batches * val_check_interval)
-    per_epoch_call_count = limit_train_batches // per_epoch_steps
-    left_over_steps = limit_train_batches % per_epoch_steps
+    per_val_train_batches = int(limit_train_batches * val_check_interval)
+    per_epoch_val_checks, leftover_train_batches = divmod(limit_train_batches, per_val_train_batches)
 
     class CustomBoringModel(BoringModel):
 
         def __init__(self):
             super().__init__()
-            self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
+            self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches)
             self.val_loop_count = 0
 
         def validation_step(self, batch, batch_idx):
             log_value = self.val_logs[self.val_loop_count, batch_idx]
             self.log('val_log', log_value)
-            self.log('epoch', self.current_epoch, on_epoch=True)
             return super().validation_step(batch, batch_idx)
 
         def validation_epoch_end(self, outputs):
@@ -213,7 +213,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
                 'strict': True,
             }
         else:
-            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
 
         return [optimizer], [lr_scheduler]
 
@@ -241,26 +241,27 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
 
     # on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
     # end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
-    additional_ckpt, additional_ckpt_path = 0, None
+    additional_ckpt, additional_ckpt_path = False, None
     if not epoch_aligned:
         additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
-        additional_ckpt = 1
+        additional_ckpt = True
 
-    additional_ckpt = 1 if not epoch_aligned else 0
-    assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
+    assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt
     assert len(lr_scheduler_debug) == max_epochs
 
-    def _make_assertions(epoch, ix, add=''):
-        global_ix = ix + per_epoch_call_count * epoch
+    def _make_assertions(epoch, ix, version=''):
+        global_ix = ix + per_epoch_val_checks * epoch
+        duplicated = bool(version)
+
         score = scores[global_ix]
         expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
-        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{add}.ckpt'
+        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
         assert math.isclose(score, expected_score, rel_tol=1e-4)
 
         chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
         assert chk['epoch'] == epoch + 1
-        epoch_num = epoch + (1 if add else 0)
-        expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
+        epoch_num = epoch + duplicated
+        expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
         assert chk['global_step'] == expected_global_step
 
         mc_specific_data = chk['callbacks'][type(checkpoint)]
@@ -269,15 +270,15 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
         assert mc_specific_data['current_score'] == score
 
         if not reduce_lr_on_plateau:
-            lr_scheduler_specific_data = chk['lr_schedulers'][0]
-            did_update = 1 if (ix + 1 == per_epoch_call_count) and (epoch_aligned or add) else 0
-            assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
-            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+            actual_step_count = chk['lr_schedulers'][0]['_step_count']
+            actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
+            assert actual_step_count == epoch + 1 + duplicated
+            assert actual_lr == lr * gamma**(epoch + duplicated)
 
         return score
 
     for epoch in range(max_epochs):
-        for i in range(per_epoch_call_count):
+        for i in range(per_epoch_val_checks):
             score = _make_assertions(epoch, i)
 
         assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
@@ -285,9 +286,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
 
     # check the ckpt file saved on_train_end
     if additional_ckpt_path:
-        epoch = max_epochs - 1
-        i = per_epoch_call_count - 1
-        _make_assertions(epoch, i, add='-v1')
+        _make_assertions(max_epochs - 1, per_epoch_val_checks - 1, version='-v1')
 
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 678f34d298..913f403a14 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -469,9 +469,6 @@ def test_trainer_model_hook_system_fit(tmpdir):
         'on_epoch_start',
         'on_train_epoch_start',
         *(model.train_batch * train_batches),
-        'training_epoch_end',
-        'on_train_epoch_end',
-        'on_epoch_end',
         'on_validation_model_eval',
         'on_validation_start',
         'on_epoch_start',
@@ -483,6 +480,9 @@ def test_trainer_model_hook_system_fit(tmpdir):
         'on_save_checkpoint',
         'on_validation_end',
         'on_validation_model_train',
+        'training_epoch_end',
+        'on_train_epoch_end',
+        'on_epoch_end',
         'on_train_end',
         'on_fit_end',
         'teardown_fit',
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index 2e17f57ec9..da4ecbe5a9 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -141,4 +141,4 @@ def test_should_stop_mid_epoch(tmpdir):
 
     assert trainer.current_epoch == 0
     assert trainer.global_step == 5
-    assert model.validation_called_at == (0, 4)  # TODO(@carmocca): should be 5 - will be fixed in next PR
+    assert model.validation_called_at == (0, 4)