lightning/pytorch_lightning/callbacks/early_stopping.py

129 lines
4.4 KiB
Python
Raw Normal View History

r"""
Early Stopping
==============
Stop training when a monitored quantity has stopped improving.
"""
import numpy as np
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
# Positive-infinity sentinel used to seed the "best so far" metric value.
# ``np.inf`` is the canonical spelling; the capitalised ``np.Inf`` alias was
# removed in NumPy 2.0.
torch_inf = torch.tensor(np.inf)
class EarlyStopping(Callback):
    r"""
    Stop training when a monitored quantity has stopped improving.

    Args:
        monitor: quantity to be monitored. Default: ``'val_loss'``.
        min_delta: minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than `min_delta`, will count as no
            improvement. Default: ``0``.
        patience: number of epochs with no improvement
            after which training will be stopped. Default: ``3``.
        verbose: verbosity mode. Default: ``False``.
        mode: one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is automatically inferred
            from the name of the monitored quantity (`max` if the
            name contains ``'acc'``, `min` otherwise). An unknown
            mode falls back to `auto`. Default: ``'auto'``.
        strict: whether to crash the training if `monitor` is
            not found in the metrics. Default: ``True``.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import EarlyStopping
        >>> early_stopping = EarlyStopping('val_loss')
        >>> trainer = Trainer(early_stop_callback=early_stopping)
    """

    def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3,
                 verbose: bool = False, mode: str = 'auto', strict: bool = True):
        super().__init__()

        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.strict = strict
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0

        mode_dict = {
            'min': torch.lt,
            'max': torch.gt,
            'auto': torch.gt if 'acc' in self.monitor else torch.lt
        }

        if mode not in mode_dict:
            if self.verbose > 0:
                log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
            mode = 'auto'

        self.monitor_op = mode_dict[mode]
        # Sign the delta so that ``current - min_delta`` always shifts the
        # candidate value *against* the direction of improvement.
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1

        # Seed the best value here as well as in ``on_train_start`` so that
        # ``on_epoch_end`` never hits a missing attribute if it runs first.
        self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf

    def _validate_condition_metric(self, logs):
        """
        Check that the monitored metric is present in ``logs``.

        Args:
            logs: mapping of metric names to values.

        Returns:
            ``True`` if ``logs`` holds a non-``None`` value for ``self.monitor``,
            ``False`` otherwise (only reachable when ``strict`` is falsy).

        Raises:
            RuntimeError: if the metric is missing and ``strict`` is truthy.
        """
        if logs.get(self.monitor) is not None:
            return True

        # Only build the message on the failure path.
        available = "`, `".join(logs.keys())
        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
                     f' which is not available. Available metrics are:'
                     f' `{available}`')

        if self.strict:
            raise RuntimeError(error_msg)
        if self.verbose > 0:
            rank_zero_warn(error_msg, RuntimeWarning)

        return False

    def on_train_start(self, trainer, pl_module):
        """Reset the counters and the best value so the instance can be re-used."""
        self.wait = 0
        self.stopped_epoch = 0
        self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf

    def on_epoch_end(self, trainer, pl_module):
        """
        Compare the monitored metric against the best value seen so far.

        Reads the metric from ``trainer.callback_metrics``. An improvement
        resets the wait counter; otherwise the counter is incremented and,
        once it reaches ``patience``, the stopping epoch is recorded and
        ``on_train_end`` is invoked.

        Returns:
            ``True`` if training should stop, ``False`` otherwise (including
            when the metric is unavailable and ``strict`` is falsy).
        """
        logs = trainer.callback_metrics
        stop_training = False
        if not self._validate_condition_metric(logs):
            return stop_training

        current = logs.get(self.monitor)
        if not isinstance(current, torch.Tensor):
            current = torch.tensor(current)

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = trainer.current_epoch
                stop_training = True
                self.on_train_end(trainer, pl_module)

        return stop_training

    def on_train_end(self, trainer, pl_module):
        """Report the stopping epoch when one was recorded and ``verbose`` is set."""
        if self.stopped_epoch > 0 and self.verbose > 0:
            rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                           ' but will start from "0" in v0.8.0.', DeprecationWarning)
            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')