Fix min-epochs and early-stopping triggering too many validation runs (#16719)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2023-02-11 04:02:39 +01:00 committed by GitHub
parent 78efde3d36
commit 5340d960b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 9 deletions

View File

@ -271,6 +271,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-
- Fixed early stopping triggering extra validation runs after reaching `min_epochs` or `min_steps` ([#16719](https://github.com/Lightning-AI/lightning/pull/16719))
## [1.9.1] - 2023-02-10
### Fixed

View File

@ -115,13 +115,13 @@ class _TrainingEpochLoop(loops._Loop):
if self.trainer.should_stop:
# early stopping
min_epochs = self.trainer.fit_loop.min_epochs
should_stop_early = self.trainer.fit_loop._should_stop_early
if not should_stop_early:
can_stop_early = self.trainer.fit_loop._can_stop_early
if not can_stop_early:
self._warning_cache.info(
f"Trainer was signaled to stop but the required `min_epochs={min_epochs!r}` or"
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
)
return should_stop_early
return can_stop_early
return False
@ -389,7 +389,9 @@ class _TrainingEpochLoop(loops._Loop):
if is_last_batch and is_infinite_dataset:
return True
if self.trainer.should_stop:
if self.trainer.should_stop and self.trainer.fit_loop._can_stop_early:
# allow validation when a stop was requested through `Trainer.should_stop` (e.g. by early stopping)
# and the loop is allowed to stop (min_epochs/steps met)
return True
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch

View File

@ -142,7 +142,7 @@ class _FitLoop(_Loop):
raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")
@property
def _should_stop_early(self) -> bool:
def _can_stop_early(self) -> bool:
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
return met_min_epochs and met_min_steps
@ -170,7 +170,7 @@ class _FitLoop(_Loop):
rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
return True
if self.trainer.should_stop and self._should_stop_early:
if self.trainer.should_stop and self._can_stop_early:
rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
return True

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
@ -75,4 +75,29 @@ def test_should_stop_early_stopping_conditions_not_met(
assert trainer.fit_loop.epoch_loop.done is epoch_loop_done
assert (message in caplog.text) is raise_info_msg
assert trainer.fit_loop._should_stop_early is early_stop
assert trainer.fit_loop._can_stop_early is early_stop
@pytest.mark.parametrize("min_epochs,min_steps,val_count", [(3, None, 3), (None, 3, 2)])
def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, tmp_path):
    """Regression test for issue #15708.

    A pending ``should_stop=True`` request must only trigger validation once the Trainer is actually allowed
    to stop, i.e. once ``min_epochs``/``min_steps`` is satisfied.
    """
    model = BoringModel()
    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "num_sanity_val_steps": 0,
        "limit_val_batches": 2,
        "limit_train_batches": 2,
        "max_epochs": 3,
        "min_epochs": min_epochs,
        "min_steps": min_steps,
        "enable_model_summary": False,
        "enable_checkpointing": False,
    }
    trainer = Trainer(**trainer_kwargs)

    # Raise the stop request up front, before min_epochs/min_steps can possibly have been reached.
    trainer.should_stop = True

    # Swap the validation loop's `run` for a mock so that each validation launch is simply counted.
    val_run = Mock()
    trainer.fit_loop.epoch_loop.val_loop.run = val_run

    trainer.fit(model)

    assert val_run.call_count == val_count

View File

@ -223,4 +223,4 @@ def test_should_stop_early_stopping_conditions_met(
assert trainer.fit_loop.done is fit_loop_done
assert (message in caplog.text) is raise_debug_msg
assert trainer.fit_loop._should_stop_early is early_stop
assert trainer.fit_loop._can_stop_early is early_stop