Smart handling of `EarlyStopping.check_on_train_epoch_end` (#8888)
* Smart handling of `EarlyStopping.check_on_train_epoch_end`
* dummy value
* Extra flag
This commit is contained in:
parent
7d87879350
commit
bfeffde8f4
|
@ -91,7 +91,7 @@ class EarlyStopping(Callback):
|
|||
check_finite: bool = True,
|
||||
stopping_threshold: Optional[float] = None,
|
||||
divergence_threshold: Optional[float] = None,
|
||||
check_on_train_epoch_end: bool = True,
|
||||
check_on_train_epoch_end: Optional[bool] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.min_delta = min_delta
|
||||
|
@ -120,6 +120,12 @@ class EarlyStopping(Callback):
|
|||
)
|
||||
self.monitor = monitor or "early_stop_on"
|
||||
|
||||
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Resolve ``_check_on_train_epoch_end`` when it has been left as ``None``.

    An explicitly stored ``True``/``False`` is never overwritten. When the
    value is ``None``, it is derived from ``trainer.val_check_interval``:
    the check runs on train epoch end only if validation happens exactly
    once per training epoch (a float interval of ``1.0``).

    Args:
        trainer: the trainer whose ``val_check_interval`` is inspected.
        pl_module: the module being trained (unused here).
    """
    if self._check_on_train_epoch_end is not None:
        # an explicit choice is already stored; leave it untouched
        return

    interval = trainer.val_check_interval
    # if the user runs validation multiple times per training epoch, we try to check after
    # validation instead of on train epoch end.
    # Per the Trainer documentation an integer interval is a number of training batches, not a
    # fraction of the epoch, so the integer ``1`` ("validate every batch") must not be mistaken
    # for the float ``1.0`` ("validate once per epoch") even though ``1 == 1.0`` in Python.
    self._check_on_train_epoch_end = not isinstance(interval, int) and interval == 1.0
|
||||
|
||||
def _validate_condition_metric(self, logs):
|
||||
monitor_val = logs.get(self.monitor)
|
||||
|
||||
|
@ -191,7 +197,7 @@ class EarlyStopping(Callback):
|
|||
# when in dev debugging
|
||||
trainer.dev_debugger.track_early_stopping_history(self, current)
|
||||
|
||||
should_stop, reason = self._evalute_stopping_criteria(current)
|
||||
should_stop, reason = self._evaluate_stopping_criteria(current)
|
||||
|
||||
# stop every ddp process if any world process decides to stop
|
||||
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
|
||||
|
@ -201,7 +207,7 @@ class EarlyStopping(Callback):
|
|||
if reason and self.verbose:
|
||||
self._log_info(trainer, reason)
|
||||
|
||||
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
|
||||
def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
|
||||
should_stop = False
|
||||
reason = None
|
||||
if self.check_finite and not torch.isfinite(current):
|
||||
|
|
|
@ -132,7 +132,6 @@ def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expec
|
|||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
callbacks=[early_stop_callback],
|
||||
val_check_interval=1.0,
|
||||
num_sanity_val_steps=0,
|
||||
max_epochs=10,
|
||||
progress_bar_refresh_rate=0,
|
||||
|
@ -417,3 +416,31 @@ def test_multiple_early_stopping_callbacks(
|
|||
num_processes=num_processes,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_check_on_train_epoch_end_with_val_check_interval(tmpdir):
    """Check the stopping criteria are evaluated after each validation run.

    Validation runs several times within the single training epoch
    (``val_check_interval=0.3``). ``_evaluate_stopping_criteria`` is patched to
    return "continue" on its first call and "stop" on its second, so training
    is expected to end after exactly two validation runs.
    """

    class TestModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            # log the monitored metric with a dummy value
            self.log("foo", 1)
            return super().validation_step(batch, batch_idx)

    val_check_interval = 0.3
    limit_train_batches = 10
    # one (should_stop, reason) pair per expected call of the patched method
    outcomes = [(False, "A"), (True, "B")]
    steps_between_checks = int(limit_train_batches * val_check_interval)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=EarlyStopping(monitor="foo"),
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=1,
        val_check_interval=val_check_interval,
        progress_bar_refresh_rate=0,
    )

    patch_target = "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria"
    with mock.patch(patch_target, side_effect=outcomes) as criteria_mock:
        trainer.fit(model)

    assert criteria_mock.call_count == len(outcomes)
    assert trainer.global_step == len(outcomes) * steps_between_checks
|
||||
|
|
Loading…
Reference in New Issue