From d12c6cf2b358c989d0d8bc17018049def99d6129 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 19 Apr 2021 16:49:52 +0200
Subject: [PATCH] more early stopping options (convergence and divergence threshold) (#6868)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                                  |  4 ++
 pytorch_lightning/callbacks/early_stopping.py | 70 ++++++++++++++++---
 tests/callbacks/test_early_stopping.py        | 49 +++++++++++--
 3 files changed, 109 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 59575e2d02..bab60910c7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -102,6 +102,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `max_time` Trainer argument to limit training time ([#6823](https://github.com/PyTorchLightning/pytorch-lightning/pull/6823))
 
+- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))
+
+
+
 ### Changed
 
 - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 24ebcdf807..9af576aafd 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -18,7 +18,8 @@ Early Stopping
 Monitor a metric and stop training when it stops improving.
 
 """
-from typing import Any, Dict
+import logging
+from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -27,6 +28,8 @@ from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+log = logging.getLogger(__name__)
+
 
 class EarlyStopping(Callback):
     r"""
@@ -53,6 +56,9 @@ class EarlyStopping(Callback):
             monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
             monitored has stopped increasing.
         strict: whether to crash the training if `monitor` is not found in the validation metrics.
+        check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
+        stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
+        divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
 
     Raises:
         MisconfigurationException:
@@ -72,6 +78,11 @@ class EarlyStopping(Callback):
         'max': torch.gt,
     }
 
+    order_dict = {
+        'min': "<",
+        'max': ">",
+    }
+
     def __init__(
         self,
         monitor: str = 'early_stop_on',
@@ -80,16 +91,22 @@ class EarlyStopping(Callback):
         verbose: bool = False,
         mode: str = 'min',
         strict: bool = True,
+        check_finite: bool = True,
+        stopping_threshold: Optional[float] = None,
+        divergence_threshold: Optional[float] = None,
     ):
         super().__init__()
         self.monitor = monitor
+        self.min_delta = min_delta
         self.patience = patience
         self.verbose = verbose
+        self.mode = mode
         self.strict = strict
-        self.min_delta = min_delta
+        self.check_finite = check_finite
+        self.stopping_threshold = stopping_threshold
+        self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
-        self.mode = mode
 
         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -160,15 +177,50 @@ class EarlyStopping(Callback):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)
 
-        if self.monitor_op(current - self.min_delta, self.best_score):
+        should_stop, reason = self._evaluate_stopping_criteria(current)
+
+        # stop every ddp process if any world process decides to stop
+        should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
+        trainer.should_stop = trainer.should_stop or should_stop
+        if should_stop:
+            self.stopped_epoch = trainer.current_epoch
+        if reason:
+            log.info(f"[{trainer.global_rank}] {reason}")
+
+    def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]:
+        should_stop = False
+        reason = None
+        if self.check_finite and not torch.isfinite(current):
+            should_stop = True
+            reason = (
+                f"Monitored metric {self.monitor} = {current} is not finite."
+                f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
+            )
+        elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
+            should_stop = True
+            reason = (
+                "Stopping threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
+                " Signaling Trainer to stop."
+            )
+        elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
+            should_stop = True
+            reason = (
+                "Divergence threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
+                " Signaling Trainer to stop."
+            )
+        elif self.monitor_op(current - self.min_delta, self.best_score):
+            should_stop = False
             self.best_score = current
             self.wait_count = 0
         else:
             self.wait_count += 1
-
             if self.wait_count >= self.patience:
-                self.stopped_epoch = trainer.current_epoch
-                trainer.should_stop = True
+                should_stop = True
+                reason = (
+                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
+                    f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
+                )
 
-        # stop every ddp process if any world process decides to stop
-        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
+        return should_stop, reason
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index cc619077ee..3844d16edb 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -213,25 +213,64 @@ def test_early_stopping_no_val_step(tmpdir):
     assert trainer.current_epoch < trainer.max_epochs - 1
 
 
-def test_early_stopping_functionality(tmpdir):
+@pytest.mark.parametrize("stopping_threshold,divergence_threshold,losses,expected_epoch", [
+    (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
+    (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
+    (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
+])
+def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch):
 
     class CurrentModel(BoringModel):
 
         def validation_epoch_end(self, outputs):
-            losses = [8, 4, 2, 3, 4, 5, 8, 10]
             val_loss = losses[self.current_epoch]
             self.log('abc', val_loss)
 
     model = CurrentModel()
-
+    early_stopping = EarlyStopping(
+        monitor='abc',
+        stopping_threshold=stopping_threshold,
+        divergence_threshold=divergence_threshold,
+    )
     trainer = Trainer(
         default_root_dir=tmpdir,
-        callbacks=[EarlyStopping(monitor='abc')],
+        callbacks=[early_stopping],
         overfit_batches=0.20,
         max_epochs=20,
     )
     trainer.fit(model)
-    assert trainer.current_epoch == 5, 'early_stopping failed'
+    assert trainer.current_epoch == expected_epoch, 'early_stopping failed'
+
+
+@pytest.mark.parametrize("stop_value", [
+    torch.tensor(np.inf),
+    torch.tensor(np.nan),
+])
+def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
+
+    losses = [4, 3, stop_value, 2, 1]
+    expected_stop_epoch = 2
+
+    class CurrentModel(BoringModel):
+
+        def validation_epoch_end(self, outputs):
+            val_loss = losses[self.current_epoch]
+            self.log('val_loss', val_loss)
+
+    model = CurrentModel()
+    early_stopping = EarlyStopping(
+        monitor='val_loss',
+        check_finite=True,
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[early_stopping],
+        overfit_batches=0.20,
+        max_epochs=10,
+    )
+    trainer.fit(model)
+    assert trainer.current_epoch == expected_stop_epoch
+    assert early_stopping.stopped_epoch == expected_stop_epoch
 
 
 @pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
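
Usage note (not part of the patch): a minimal sketch of how the options introduced above could be combined in user code. The metric name "val_loss", the threshold values, and the commented-out model are illustrative placeholders.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Stop once val_loss becomes "good enough" (stopping_threshold), as soon as it
    # becomes hopelessly bad (divergence_threshold), or if it turns NaN/inf
    # (check_finite); otherwise the usual patience-based rule still applies.
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode="min",                  # lower is better; thresholds are compared with torch.lt
        patience=3,                  # classic no-improvement stopping is unchanged
        check_finite=True,           # stop on NaN or infinite monitor values
        stopping_threshold=0.05,     # stop immediately once val_loss < 0.05
        divergence_threshold=10.0,   # stop as soon as val_loss > 10.0
    )

    trainer = Trainer(max_epochs=100, callbacks=[early_stopping])
    # trainer.fit(MyLightningModule())  # placeholder model, defined elsewhere

In 'min' mode the divergence check negates both sides (monitor_op(-current, -self.divergence_threshold)), so a monitor value larger than the threshold triggers the stop; in 'max' mode the same trick stops on values smaller than the threshold.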