From bfeffde8f403f086d7cc76fc5d1749782a3e385d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Sat, 14 Aug 2021 08:50:39 +0200
Subject: [PATCH] Smart handling of `EarlyStopping.check_on_train_epoch_end`
 (#8888)

* Smart handling of `EarlyStopping.check_on_train_epoch_end`

* dummy value

* Extra flag
---
 pytorch_lightning/callbacks/early_stopping.py | 12 ++++++--
 tests/callbacks/test_early_stopping.py        | 29 ++++++++++++++++++-
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 7def99d3cf..4df926b796 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -91,7 +91,7 @@ class EarlyStopping(Callback):
         check_finite: bool = True,
         stopping_threshold: Optional[float] = None,
         divergence_threshold: Optional[float] = None,
-        check_on_train_epoch_end: bool = True,
+        check_on_train_epoch_end: Optional[bool] = None,
     ):
         super().__init__()
         self.min_delta = min_delta
@@ -120,6 +120,12 @@ class EarlyStopping(Callback):
             )
         self.monitor = monitor or "early_stop_on"
 
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self._check_on_train_epoch_end is None:
+            # if the user runs validation multiple times per training epoch, we try to check after
+            # validation instead of on train epoch end
+            self._check_on_train_epoch_end = trainer.val_check_interval == 1.0
+
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)
 
@@ -191,7 +197,7 @@ class EarlyStopping(Callback):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)
 
-        should_stop, reason = self._evalute_stopping_criteria(current)
+        should_stop, reason = self._evaluate_stopping_criteria(current)
 
         # stop every ddp process if any world process decides to stop
         should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
@@ -201,7 +207,7 @@ class EarlyStopping(Callback):
         if reason and self.verbose:
             self._log_info(trainer, reason)
 
-    def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
+    def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
         should_stop = False
         reason = None
         if self.check_finite and not torch.isfinite(current):
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index f7e4968e61..4c3b990dd1 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -132,7 +132,6 @@ def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expec
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback],
-        val_check_interval=1.0,
         num_sanity_val_steps=0,
         max_epochs=10,
         progress_bar_refresh_rate=0,
@@ -417,3 +416,31 @@ def test_multiple_early_stopping_callbacks(
         num_processes=num_processes,
     )
     trainer.fit(model)
+
+
+def test_check_on_train_epoch_end_with_val_check_interval(tmpdir):
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            self.log("foo", 1)
+            return super().validation_step(batch, batch_idx)
+
+    model = TestModel()
+    val_check_interval, limit_train_batches = 0.3, 10
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        val_check_interval=val_check_interval,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=1,
+        callbacks=EarlyStopping(monitor="foo"),
+        progress_bar_refresh_rate=0,
+    )
+
+    side_effect = [(False, "A"), (True, "B")]
+    with mock.patch(
+        "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria", side_effect=side_effect
+    ) as es_mock:
+        trainer.fit(model)
+
+    assert es_mock.call_count == len(side_effect)
+    assert trainer.global_step == len(side_effect) * int(limit_train_batches * val_check_interval)