Fix min-epochs and early-stopping triggering too many validation runs (#16719)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
parent
78efde3d36
commit
5340d960b9
|
@ -271,6 +271,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
-
|
||||
|
||||
|
||||
- Fixed early stopping triggering extra validation runs after reaching `min_epochs` or `min_steps` ([#16719](https://github.com/Lightning-AI/lightning/pull/16719))
|
||||
|
||||
|
||||
## [1.9.1] - 2023-02-10
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -115,13 +115,13 @@ class _TrainingEpochLoop(loops._Loop):
|
|||
if self.trainer.should_stop:
|
||||
# early stopping
|
||||
min_epochs = self.trainer.fit_loop.min_epochs
|
||||
should_stop_early = self.trainer.fit_loop._should_stop_early
|
||||
if not should_stop_early:
|
||||
can_stop_early = self.trainer.fit_loop._can_stop_early
|
||||
if not can_stop_early:
|
||||
self._warning_cache.info(
|
||||
f"Trainer was signaled to stop but the required `min_epochs={min_epochs!r}` or"
|
||||
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
|
||||
)
|
||||
return should_stop_early
|
||||
return can_stop_early
|
||||
|
||||
return False
|
||||
|
||||
|
@ -389,7 +389,9 @@ class _TrainingEpochLoop(loops._Loop):
|
|||
if is_last_batch and is_infinite_dataset:
|
||||
return True
|
||||
|
||||
if self.trainer.should_stop:
|
||||
if self.trainer.should_stop and self.trainer.fit_loop._can_stop_early:
|
||||
# allow validation if requesting to stop early through `Trainer.should_stop` (e.g. by early stopping)
|
||||
# and only once the loop is allowed to stop (min_epochs/min_steps have been met)
|
||||
return True
|
||||
|
||||
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
|
||||
|
|
|
@ -142,7 +142,7 @@ class _FitLoop(_Loop):
|
|||
raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")
|
||||
|
||||
@property
|
||||
def _should_stop_early(self) -> bool:
|
||||
def _can_stop_early(self) -> bool:
|
||||
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
|
||||
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
|
||||
return met_min_epochs and met_min_steps
|
||||
|
@ -170,7 +170,7 @@ class _FitLoop(_Loop):
|
|||
rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
|
||||
return True
|
||||
|
||||
if self.trainer.should_stop and self._should_stop_early:
|
||||
if self.trainer.should_stop and self._can_stop_early:
|
||||
rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
|
||||
return True
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -75,4 +75,29 @@ def test_should_stop_early_stopping_conditions_not_met(
|
|||
assert trainer.fit_loop.epoch_loop.done is epoch_loop_done
|
||||
|
||||
assert (message in caplog.text) is raise_info_msg
|
||||
assert trainer.fit_loop._should_stop_early is early_stop
|
||||
assert trainer.fit_loop._can_stop_early is early_stop
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_epochs,min_steps,val_count", [(3, None, 3), (None, 3, 2)])
def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, tmp_path):
    """Regression test for issue #15708.

    A stop request made through `should_stop=True` must only trigger a validation run once the Trainer is
    actually allowed to stop, i.e. after `min_epochs`/`min_steps` has been satisfied.
    """
    model = BoringModel()
    trainer_kwargs = dict(
        default_root_dir=tmp_path,
        num_sanity_val_steps=0,
        limit_val_batches=2,
        limit_train_batches=2,
        max_epochs=3,
        min_epochs=min_epochs,
        min_steps=min_steps,
        enable_model_summary=False,
        enable_checkpointing=False,
    )
    trainer = Trainer(**trainer_kwargs)

    # Ask the Trainer to stop before min_epochs/min_steps are reached.
    trainer.should_stop = True

    # Replace the validation loop's `run` with a mock so the number of validation runs can be counted.
    val_loop = trainer.fit_loop.epoch_loop.val_loop
    val_loop.run = Mock()

    trainer.fit(model)

    assert val_loop.run.call_count == val_count
|
||||
|
|
|
@ -223,4 +223,4 @@ def test_should_stop_early_stopping_conditions_met(
|
|||
assert trainer.fit_loop.done is fit_loop_done
|
||||
|
||||
assert (message in caplog.text) is raise_debug_msg
|
||||
assert trainer.fit_loop._should_stop_early is early_stop
|
||||
assert trainer.fit_loop._can_stop_early is early_stop
|
||||
|
|
Loading…
Reference in New Issue