Smart handling of `EarlyStopping.check_on_train_epoch_end` (#8888)

* Smart handling of `EarlyStopping.check_on_train_epoch_end`

* dummy value

* Extra flag
Carlos Mocholí 2021-08-14 08:50:39 +02:00 committed by GitHub
parent 7d87879350
commit bfeffde8f4
2 changed files with 37 additions and 4 deletions


@@ -91,7 +91,7 @@ class EarlyStopping(Callback):
         check_finite: bool = True,
         stopping_threshold: Optional[float] = None,
         divergence_threshold: Optional[float] = None,
-        check_on_train_epoch_end: bool = True,
+        check_on_train_epoch_end: Optional[bool] = None,
     ):
         super().__init__()
         self.min_delta = min_delta
@@ -120,6 +120,12 @@ class EarlyStopping(Callback):
             )
         self.monitor = monitor or "early_stop_on"
 
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self._check_on_train_epoch_end is None:
+            # if the user runs validation multiple times per training epoch, we try to check after
+            # validation instead of on train epoch end
+            self._check_on_train_epoch_end = trainer.val_check_interval == 1.0
+
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)
@@ -191,7 +197,7 @@ class EarlyStopping(Callback):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)
 
-        should_stop, reason = self._evalute_stopping_criteria(current)
+        should_stop, reason = self._evaluate_stopping_criteria(current)
 
         # stop every ddp process if any world process decides to stop
         should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
@@ -201,7 +207,7 @@ class EarlyStopping(Callback):
         if reason and self.verbose:
             self._log_info(trainer, reason)
 
-    def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
+    def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
         should_stop = False
         reason = None
         if self.check_finite and not torch.isfinite(current):
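
With the flag now defaulting to `None`, the callback decides during `on_pretrain_routine_start` whether to run its check at the end of the training epoch: it does so only when `trainer.val_check_interval == 1.0`, and otherwise checks after each validation run. A minimal sketch of the new default from the user side (illustrative only; the monitored metric name "foo" is a placeholder, not part of this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# check_on_train_epoch_end is left as None, so the callback resolves it itself
early_stopping = EarlyStopping(monitor="foo")

trainer = Trainer(
    val_check_interval=0.25,  # validation runs four times per training epoch
    callbacks=[early_stopping],
)
# During on_pretrain_routine_start the callback sets
#   _check_on_train_epoch_end = (trainer.val_check_interval == 1.0)  # False here
# so the stopping criteria are evaluated after every validation run instead of
# at the end of each training epoch.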


@@ -132,7 +132,6 @@ def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expec
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback],
-        val_check_interval=1.0,
         num_sanity_val_steps=0,
         max_epochs=10,
         progress_bar_refresh_rate=0,
@@ -417,3 +416,31 @@ def test_multiple_early_stopping_callbacks(
         num_processes=num_processes,
     )
     trainer.fit(model)
+
+
+def test_check_on_train_epoch_end_with_val_check_interval(tmpdir):
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            self.log("foo", 1)
+            return super().validation_step(batch, batch_idx)
+
+    model = TestModel()
+    val_check_interval, limit_train_batches = 0.3, 10
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        val_check_interval=val_check_interval,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=1,
+        callbacks=EarlyStopping(monitor="foo"),
+        progress_bar_refresh_rate=0,
+    )
+    side_effect = [(False, "A"), (True, "B")]
+    with mock.patch(
+        "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria", side_effect=side_effect
+    ) as es_mock:
+        trainer.fit(model)
+
+    assert es_mock.call_count == len(side_effect)
+    assert trainer.global_step == len(side_effect) * int(limit_train_batches * val_check_interval)
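
For reference, the final assertion works out as follows: with val_check_interval=0.3 and limit_train_batches=10, validation (and therefore the early-stopping check) runs every int(10 * 0.3) = 3 training batches; the mocked criteria request a stop on the second check, so training is expected to end at trainer.global_step == 2 * 3 == 6.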