diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index c5b7fe92f7..2bf569b576 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -271,6 +271,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +- Fixed early stopping triggering extra validation runs after reaching `min_epochs` or `min_steps` ([#16719](https://github.com/Lightning-AI/lightning/pull/16719)) + + ## [1.9.1] - 2023-02-10 ### Fixed diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py index 3ac41edccb..88e3bbc22e 100644 --- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py @@ -115,13 +115,13 @@ class _TrainingEpochLoop(loops._Loop): if self.trainer.should_stop: # early stopping min_epochs = self.trainer.fit_loop.min_epochs - should_stop_early = self.trainer.fit_loop._should_stop_early - if not should_stop_early: + can_stop_early = self.trainer.fit_loop._can_stop_early + if not can_stop_early: self._warning_cache.info( f"Trainer was signaled to stop but the required `min_epochs={min_epochs!r}` or" f" `min_steps={self.min_steps!r}` has not been met. Training will continue..." ) - return should_stop_early + return can_stop_early return False @@ -389,7 +389,9 @@ class _TrainingEpochLoop(loops._Loop): if is_last_batch and is_infinite_dataset: return True - if self.trainer.should_stop: + if self.trainer.should_stop and self.trainer.fit_loop._can_stop_early: + # allow validation if requesting to stop early through `Trainer.should_stop` (e.g. by early stopping) + # and when the loop allows to stop (min_epochs/steps met) return True # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index fd95e51406..c65cfedf2f 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -142,7 +142,7 @@ class _FitLoop(_Loop): raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope") @property - def _should_stop_early(self) -> bool: + def _can_stop_early(self) -> bool: met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True return met_min_epochs and met_min_steps @@ -170,7 +170,7 @@ class _FitLoop(_Loop): rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.") return True - if self.trainer.should_stop and self._should_stop_early: + if self.trainer.should_stop and self._can_stop_early: rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.") return True diff --git a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py index c639dec84d..ba305a461f 100644 --- a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest @@ -75,4 +75,29 @@ def test_should_stop_early_stopping_conditions_not_met( assert trainer.fit_loop.epoch_loop.done is epoch_loop_done assert (message in caplog.text) is raise_info_msg - assert trainer.fit_loop._should_stop_early is early_stop + assert trainer.fit_loop._can_stop_early is early_stop + + +@pytest.mark.parametrize("min_epochs,min_steps,val_count", [(3, None, 3), (None, 3, 2)]) +def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, tmp_path): + """Regression test for issue #15708. + + Test that the request for `should_stop=True` only triggers validation when Trainer is allowed to stop + (min_epochs/steps is satisfied). + """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + num_sanity_val_steps=0, + limit_val_batches=2, + limit_train_batches=2, + max_epochs=3, + min_epochs=min_epochs, + min_steps=min_steps, + enable_model_summary=False, + enable_checkpointing=False, + ) + trainer.should_stop = True # Request to stop before min_epochs/min_steps are reached + trainer.fit_loop.epoch_loop.val_loop.run = Mock() + trainer.fit(model) + assert trainer.fit_loop.epoch_loop.val_loop.run.call_count == val_count diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 2e029ed758..78acbc0ecf 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -223,4 +223,4 @@ def test_should_stop_early_stopping_conditions_met( assert trainer.fit_loop.done is fit_loop_done assert (message in caplog.text) is raise_debug_msg - assert trainer.fit_loop._should_stop_early is early_stop + assert trainer.fit_loop._can_stop_early is early_stop