From a1c40f3207ae947bf9146d32c06138bc78b2c535 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 25 May 2021 16:59:42 +0200 Subject: [PATCH] Remove on epoch guard from the should stop validation check (#7701) * Remove on epoch guard from the should stop validation check * Formatting --- .../callbacks/gpu_stats_monitor.py | 4 +-- pytorch_lightning/callbacks/lr_monitor.py | 4 +-- pytorch_lightning/trainer/training_loop.py | 18 +++-------- tests/trainer/loops/test_training_loop.py | 32 +++++++++++++++++++ 4 files changed, 38 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index ffd39e9af4..794165fe60 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -211,6 +211,4 @@ class GPUStatsMonitor(Callback): @staticmethod def _should_log(trainer) -> bool: - should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop) - - return should_log + return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 7530bfaa9d..410f8b319c 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -202,6 +202,4 @@ class LearningRateMonitor(Callback): @staticmethod def _should_log(trainer) -> bool: - should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop) - - return should_log + return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1906679a2b..6213879013 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -529,21 +529,11 @@ class TrainLoop: self.total_batch_idx += 1 - # max steps reached, end training - if ( + max_steps_reached = ( self.max_steps is not None and self.max_steps <= self.global_step + 1 and self._accumulated_batches_reached() - ): - break - - # end epoch early - # stop when the flag is changed or we've gone past the amount - # requested in the batches - if self.trainer.should_stop: - break - - # stop epoch if we limited the number of training batches - if self._num_training_batches_reached(is_last_batch): + ) + if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch): break # progress global step according to grads progress @@ -906,7 +896,7 @@ class TrainLoop: if on_epoch and is_last_batch and is_infinite_dataset: return True - if on_epoch and self.trainer.should_stop: + if self.trainer.should_stop: return True # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index b89909a40f..2e17f57ec9 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -110,3 +110,35 @@ def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_): else: assert trainer.train_loop.batch_idx == batch_idx_ assert trainer.global_step == batch_idx_ * max_epochs + + +def test_should_stop_mid_epoch(tmpdir): + """Test that training correctly stops mid epoch and that validation is still called at the right time""" + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.validation_called_at = None + + def training_step(self, batch, batch_idx): + if batch_idx == 4: + self.trainer.should_stop = True + return super().training_step(batch, batch_idx) + + def validation_step(self, *args): + self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step) + return super().validation_step(*args) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=10, + limit_val_batches=1, + ) + trainer.fit(model) + + assert trainer.current_epoch == 0 + assert trainer.global_step == 5 + assert model.validation_called_at == (0, 4) # TODO(@carmocca): should be 5 - will be fixed in next PR