more early stopping options (convergence and divergence threshold) (#6868)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Adrian Wälchli 2021-04-19 16:49:52 +02:00 committed by GitHub
parent 60c1c8fe83
commit d12c6cf2b3
3 changed files with 109 additions and 14 deletions


@@ -102,6 +102,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `max_time` Trainer argument to limit training time ([#6823](https://github.com/PyTorchLightning/pytorch-lightning/pull/6823))
- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))
### Changed
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))


@@ -18,7 +18,8 @@ Early Stopping
Monitor a metric and stop training when it stops improving.
"""
from typing import Any, Dict
import logging
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
@@ -27,6 +28,8 @@ from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
log = logging.getLogger(__name__)
class EarlyStopping(Callback):
r"""
@@ -53,6 +56,9 @@ class EarlyStopping(Callback):
monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
monitored has stopped increasing.
strict: whether to crash the training if `monitor` is not found in the validation metrics.
check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
Raises:
MisconfigurationException:
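
A minimal usage sketch of the parameters documented above; the metric name ``val_loss``, the threshold values, and the Trainer settings are illustrative assumptions and not part of this diff:

```python
# A minimal sketch, assuming a LightningModule that logs "val_loss" and mode='min'.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    check_finite=True,          # stop if val_loss becomes NaN or infinite
    stopping_threshold=0.02,    # stop once val_loss drops below 0.02
    divergence_threshold=10.0,  # stop if val_loss climbs above 10.0
)

trainer = Trainer(max_epochs=100, callbacks=[early_stopping])
# trainer.fit(model)  # `model` is a placeholder for your LightningModule
```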
@@ -72,6 +78,11 @@ class EarlyStopping(Callback):
'max': torch.gt,
}
order_dict = {
'min': "<",
'max': ">",
}
def __init__(
self,
monitor: str = 'early_stop_on',
@@ -80,16 +91,22 @@ class EarlyStopping(Callback):
verbose: bool = False,
mode: str = 'min',
strict: bool = True,
check_finite: bool = True,
stopping_threshold: Optional[float] = None,
divergence_threshold: Optional[float] = None,
):
super().__init__()
self.monitor = monitor
self.min_delta = min_delta
self.patience = patience
self.verbose = verbose
self.mode = mode
self.strict = strict
self.min_delta = min_delta
self.check_finite = check_finite
self.stopping_threshold = stopping_threshold
self.divergence_threshold = divergence_threshold
self.wait_count = 0
self.stopped_epoch = 0
self.mode = mode
if self.mode not in self.mode_dict:
raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -160,15 +177,50 @@ class EarlyStopping(Callback):
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)
if self.monitor_op(current - self.min_delta, self.best_score):
should_stop, reason = self._evaluate_stopping_criteria(current)
# stop every ddp process if any world process decides to stop
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop:
self.stopped_epoch = trainer.current_epoch
if reason:
log.info(f"[{trainer.global_rank}] {reason}")
def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]:
should_stop = False
reason = None
if self.check_finite and not torch.isfinite(current):
should_stop = True
reason = (
f"Monitored metric {self.monitor} = {current} is not finite."
f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
)
elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
should_stop = True
reason = (
"Stopping threshold reached:"
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
" Signaling Trainer to stop."
)
elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
should_stop = True
reason = (
"Divergence threshold reached:"
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
" Signaling Trainer to stop."
)
elif self.monitor_op(current - self.min_delta, self.best_score):
should_stop = False
self.best_score = current
self.wait_count = 0
else:
self.wait_count += 1
if self.wait_count >= self.patience:
self.stopped_epoch = trainer.current_epoch
trainer.should_stop = True
should_stop = True
reason = (
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
)
# stop every ddp process if any world process decides to stop
trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
return should_stop, reason
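
For the divergence check above, negating both operands lets the callback reuse the same ``monitor_op`` it already uses for the improvement check. A small standalone sketch with made-up values illustrates the equivalence in ``'min'`` mode:

```python
import torch

monitor_op = torch.lt               # operator used in 'min' mode
current = torch.tensor(42.0)        # made-up monitored value
divergence_threshold = 10.0

# monitor_op(-current, -threshold) is True exactly when current > threshold,
# i.e. when the metric has become worse than the divergence threshold in 'min' mode.
diverged = monitor_op(-current, -divergence_threshold)
assert bool(diverged) == bool(current > divergence_threshold)  # both True here
```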


@@ -213,25 +213,64 @@ def test_early_stopping_no_val_step(tmpdir):
assert trainer.current_epoch < trainer.max_epochs - 1
def test_early_stopping_functionality(tmpdir):
@pytest.mark.parametrize("stopping_threshold,divergence_theshold,losses,expected_epoch", [
(None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
(2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
(None, 15.9, [9, 4, 2, 16, 32, 64], 3),
])
def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch):
class CurrentModel(BoringModel):
def validation_epoch_end(self, outputs):
losses = [8, 4, 2, 3, 4, 5, 8, 10]
val_loss = losses[self.current_epoch]
self.log('abc', val_loss)
model = CurrentModel()
early_stopping = EarlyStopping(
monitor='abc',
stopping_threshold=stopping_threshold,
divergence_threshold=divergence_threshold,
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[EarlyStopping(monitor='abc')],
callbacks=[early_stopping],
overfit_batches=0.20,
max_epochs=20,
)
trainer.fit(model)
assert trainer.current_epoch == 5, 'early_stopping failed'
assert trainer.current_epoch == expected_epoch, 'early_stopping failed'
@pytest.mark.parametrize("stop_value", [
torch.tensor(np.inf),
torch.tensor(np.nan),
])
def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
losses = [4, 3, stop_value, 2, 1]
expected_stop_epoch = 2
class CurrentModel(BoringModel):
def validation_epoch_end(self, outputs):
val_loss = losses[self.current_epoch]
self.log('val_loss', val_loss)
model = CurrentModel()
early_stopping = EarlyStopping(
monitor='val_loss',
check_finite=True,
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stopping],
overfit_batches=0.20,
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
assert early_stopping.stopped_epoch == expected_stop_epoch
@pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])