more early stopping options (convergence and divergence threshold) (#6868)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent 60c1c8fe83
commit d12c6cf2b3
@@ -102,6 +102,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `max_time` Trainer argument to limit training time ([#6823](https://github.com/PyTorchLightning/pytorch-lightning/pull/6823))

- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))


### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))

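For orientation (not part of the diff): a minimal usage sketch of the two new thresholds. The metric name, threshold values, and `max_epochs` below are illustrative; the `EarlyStopping` and `Trainer` arguments are the ones introduced or used in this commit.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Stop as soon as the monitored metric reaches 0.05 (stopping_threshold), or
# bail out immediately if it ever gets worse than 10.0 (divergence_threshold).
# Both thresholds are interpreted through `mode` ('min' here: lower is better).
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    stopping_threshold=0.05,
    divergence_threshold=10.0,
)
trainer = Trainer(callbacks=[early_stopping], max_epochs=20)
# trainer.fit(model)  # any LightningModule that logs "val_loss"
```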
@@ -18,7 +18,8 @@ Early Stopping

Monitor a metric and stop training when it stops improving.

"""
from typing import Any, Dict
import logging
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch

@@ -27,6 +28,8 @@ from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)


class EarlyStopping(Callback):
    r"""

@@ -53,6 +56,9 @@ class EarlyStopping(Callback):
            monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
            monitored has stopped increasing.
        strict: whether to crash the training if `monitor` is not found in the validation metrics.
        check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
        stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
        divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.

    Raises:
        MisconfigurationException:

@@ -72,6 +78,11 @@ class EarlyStopping(Callback):
        'max': torch.gt,
    }

    order_dict = {
        'min': "<",
        'max': ">",
    }

    def __init__(
        self,
        monitor: str = 'early_stop_on',

@@ -80,16 +91,22 @@ class EarlyStopping(Callback):
        verbose: bool = False,
        mode: str = 'min',
        strict: bool = True,
        check_finite: bool = True,
        stopping_threshold: Optional[float] = None,
        divergence_threshold: Optional[float] = None,
    ):
        super().__init__()
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.mode = mode
        self.strict = strict
        self.min_delta = min_delta
        self.check_finite = check_finite
        self.stopping_threshold = stopping_threshold
        self.divergence_threshold = divergence_threshold
        self.wait_count = 0
        self.stopped_epoch = 0
        self.mode = mode

        if self.mode not in self.mode_dict:
            raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")

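A side note on the `mode_dict` / `order_dict` pair added above (not part of the diff): `mode` selects the torch comparison operator used for every check in the callback, while `order_dict` only supplies the matching symbol for log messages. A standalone illustration with made-up values:

```python
import torch

# Mirrors the class attributes shown in the diff above.
mode_dict = {'min': torch.lt, 'max': torch.gt}
order_dict = {'min': "<", 'max': ">"}

mode = 'min'                          # lower metric values are better
monitor_op = mode_dict[mode]
current, best = torch.tensor(0.30), torch.tensor(0.45)

improved = monitor_op(current, best)  # tensor(True): 0.30 < 0.45
print(f"{current.item()} {order_dict[mode]} {best.item()} -> improved: {improved.item()}")
```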
@@ -160,15 +177,50 @@ class EarlyStopping(Callback):
        # when in dev debugging
        trainer.dev_debugger.track_early_stopping_history(self, current)

        if self.monitor_op(current - self.min_delta, self.best_score):
        should_stop, reason = self._evalute_stopping_criteria(current)

        # stop every ddp process if any world process decides to stop
        should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
        trainer.should_stop = trainer.should_stop or should_stop
        if should_stop:
            self.stopped_epoch = trainer.current_epoch
        if reason:
            log.info(f"[{trainer.global_rank}] {reason}")

    def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
        should_stop = False
        reason = None
        if self.check_finite and not torch.isfinite(current):
            should_stop = True
            reason = (
                f"Monitored metric {self.monitor} = {current} is not finite."
                f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
            )
        elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
            should_stop = True
            reason = (
                "Stopping threshold reached:"
                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
                " Signaling Trainer to stop."
            )
        elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
            should_stop = True
            reason = (
                "Divergence threshold reached:"
                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
                " Signaling Trainer to stop."
            )
        elif self.monitor_op(current - self.min_delta, self.best_score):
            should_stop = False
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait_count += 1

            if self.wait_count >= self.patience:
                self.stopped_epoch = trainer.current_epoch
                trainer.should_stop = True
                should_stop = True
                reason = (
                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
                    f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
                )

        # stop every ddp process if any world process decides to stop
        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
        return should_stop, reason

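One detail worth calling out in `_evalute_stopping_criteria` above: the divergence check reuses `monitor_op` by negating both operands, which flips the comparison direction, so in `'min'` mode it fires when the metric climbs above `divergence_threshold`. A standalone sketch with made-up numbers (not part of the diff):

```python
import torch

monitor_op = torch.lt                 # 'min' mode: improvement means "less than"
current = torch.tensor(32.0)          # monitored loss after a bad epoch
divergence_threshold = 15.9

# Negating both sides turns "less than" into "greater than":
# -32.0 < -15.9 is equivalent to 32.0 > 15.9, i.e. the loss has diverged.
diverged = monitor_op(-current, -divergence_threshold)
print(diverged)  # tensor(True) -> EarlyStopping signals the Trainer to stop
```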
@@ -213,25 +213,64 @@ def test_early_stopping_no_val_step(tmpdir):
    assert trainer.current_epoch < trainer.max_epochs - 1


def test_early_stopping_functionality(tmpdir):
@pytest.mark.parametrize("stopping_threshold,divergence_threshold,losses,expected_epoch", [
    (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
    (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
    (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
])
def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch):

    class CurrentModel(BoringModel):

        def validation_epoch_end(self, outputs):
            losses = [8, 4, 2, 3, 4, 5, 8, 10]
            val_loss = losses[self.current_epoch]
            self.log('abc', val_loss)

    model = CurrentModel()

    early_stopping = EarlyStopping(
        monitor='abc',
        stopping_threshold=stopping_threshold,
        divergence_threshold=divergence_threshold,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[EarlyStopping(monitor='abc')],
        callbacks=[early_stopping],
        overfit_batches=0.20,
        max_epochs=20,
    )
    trainer.fit(model)
    assert trainer.current_epoch == 5, 'early_stopping failed'
    assert trainer.current_epoch == expected_epoch, 'early_stopping failed'


@pytest.mark.parametrize("stop_value", [
    torch.tensor(np.inf),
    torch.tensor(np.nan),
])
def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):

    losses = [4, 3, stop_value, 2, 1]
    expected_stop_epoch = 2

    class CurrentModel(BoringModel):

        def validation_epoch_end(self, outputs):
            val_loss = losses[self.current_epoch]
            self.log('val_loss', val_loss)

    model = CurrentModel()
    early_stopping = EarlyStopping(
        monitor='val_loss',
        check_finite=True,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stopping],
        overfit_batches=0.20,
        max_epochs=10,
    )
    trainer.fit(model)
    assert trainer.current_epoch == expected_stop_epoch
    assert early_stopping.stopped_epoch == expected_stop_epoch


@pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])