Check early stopping metric in the beginning of the training (#542)

* Early stopping fix

* Update trainer.py

* Don't force validation sanity check

* fix tests

* update

* Added early_stopping check_metrics

* Updated docs

* Update docs

* Do not call early stopping when validation is disabled

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
Vadim Bereznyuk 2020-01-23 19:12:51 +03:00 committed by William Falcon
parent 588ad83771
commit 50881c0b31
6 changed files with 65 additions and 22 deletions

View File

@@ -71,21 +71,23 @@ class EarlyStopping(Callback):
Stop training when a monitored quantity has stopped improving.
Args:
monitor (str): quantity to be monitored.
monitor (str): quantity to be monitored. Default: ``'val_loss'``.
min_delta (float): minimum change in the monitored quantity
to qualify as an improvement, i.e. an absolute
change of less than min_delta, will count as no
improvement.
change of less than `min_delta`, will count as no
improvement. Default: ``0``.
patience (int): number of epochs with no improvement
after which training will be stopped.
verbose (bool): verbosity mode.
after which training will be stopped. Default: ``0``.
verbose (bool): verbosity mode. Default: ``0``.
mode (str): one of {auto, min, max}. In `min` mode,
training will stop when the quantity
monitored has stopped decreasing; in `max`
mode it will stop when the quantity
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity.
from the name of the monitored quantity. Default: ``'auto'``.
strict (bool): whether to crash the training if `monitor` is
not found in the metrics. Default: ``True``.
Example::
@@ -97,18 +99,20 @@ class EarlyStopping(Callback):
"""
def __init__(self, monitor='val_loss',
min_delta=0.0, patience=0, verbose=0, mode='auto'):
min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
super(EarlyStopping, self).__init__()
self.monitor = monitor
self.patience = patience
self.verbose = verbose
self.strict = strict
self.min_delta = min_delta
self.wait = 0
self.stopped_epoch = 0
if mode not in ['auto', 'min', 'max']:
logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
if self.verbose > 0:
logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
mode = 'auto'
if mode == 'min':
@@ -128,6 +132,22 @@ class EarlyStopping(Callback):
self.on_train_begin()
def check_metrics(self, logs):
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
f' which is not available. Available metrics are:'
f' `{"`, `".join(list(logs.keys()))}`')
if monitor_val is None:
if self.strict:
raise RuntimeError(error_msg)
elif self.verbose > 0:
warnings.warn(error_msg, RuntimeWarning)
return False
return True
def on_train_begin(self, logs=None):
# Allow instances to be re-used
self.wait = 0
@@ -135,16 +155,11 @@ class EarlyStopping(Callback):
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
stop_training = False
if current is None:
warnings.warn(
f'Early stopping conditioned on metric `{self.monitor}`'
f' which is not available. Available metrics are: {",".join(list(logs.keys()))}',
RuntimeWarning)
stop_training = True
if not self.check_metrics(logs):
return stop_training
current = logs.get(self.monitor)
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0

View File

@@ -55,10 +55,20 @@ class TrainerCallbackConfigMixin(ABC):
self.early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=3,
strict=True,
verbose=True,
mode='min'
)
self.enable_early_stop = True
elif early_stop_callback is None:
self.early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=3,
strict=False,
verbose=False,
mode='min'
)
self.enable_early_stop = True
elif not early_stop_callback:
self.early_stop_callback = None
self.enable_early_stop = False

View File

@@ -52,7 +52,7 @@ class Trainer(TrainerIOMixin,
self,
logger=True,
checkpoint_callback=True,
early_stop_callback=True,
early_stop_callback=None,
default_save_path=None,
gradient_clip_val=0,
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
@@ -121,7 +121,13 @@ class Trainer(TrainerIOMixin,
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping
early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping. If
set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
Will raise an error if ``'val_loss'`` is not found.
If set to ``False``, then early stopping will be disabled.
If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
If ``'val_loss'`` is not found will work as if early stopping is disabled.
Default: ``None``.
Example::
from pytorch_lightning.callbacks import EarlyStopping
@@ -129,7 +135,8 @@ class Trainer(TrainerIOMixin,
early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=3,
verbose=True,
strict=False,
verbose=False,
mode='min'
)
@@ -809,12 +816,17 @@ class Trainer(TrainerIOMixin,
# dummy validation progress bar
self.val_progress_bar = tqdm.tqdm(disable=True)
self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
eval_results = self.evaluate(model, self.get_val_dataloaders(),
self.num_sanity_val_steps, False)
_, _, _, callback_metrics, _ = self.process_output(eval_results)
# close progress bars
self.main_progress_bar.close()
self.val_progress_bar.close()
if self.enable_early_stop:
self.early_stop_callback.check_metrics(callback_metrics)
# init progress bar
pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',

View File

@@ -346,7 +346,8 @@ class TrainerTrainLoopMixin(ABC):
# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
if (self.enable_early_stop and not self.disable_validation and
(met_min_epochs or self.fast_dev_run)):
should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
logs=self.callback_metrics)
# stop training
@@ -401,6 +402,9 @@ class TrainerTrainLoopMixin(ABC):
if self.fast_dev_run or should_check_val:
self.run_evaluation(test=self.testing)
if self.enable_early_stop:
self.early_stop_callback.check_metrics(self.callback_metrics)
# when logs should be saved
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:

View File

@@ -140,7 +140,8 @@ def test_running_test_without_val(tmpdir):
val_percent_check=0.2,
test_percent_check=0.2,
checkpoint_callback=checkpoint,
logger=logger
logger=logger,
early_stop_callback=False
)
# fit model
@@ -318,6 +319,7 @@ def test_tbptt_cpu_model(tmpdir):
truncated_bptt_steps=truncated_bptt_steps,
val_percent_check=0,
weights_summary=None,
early_stop_callback=False
)
hparams = tutils.get_hparams()

View File

@@ -392,7 +392,7 @@ def test_multiple_test_dataloader(tmpdir):
default_save_path=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2,
train_percent_check=0.2
)
# fit model