Check early stopping metric at the beginning of training (#542)
* Early stopping fix
* Update trainer.py
* Don't force validation sanity check
* fix tests
* update
* Added early_stopping check_metrics
* Updated docs
* Update docs
* Do not call early stopping when validation is disabled

Co-authored-by: William Falcon <waf2107@columbia.edu>
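For context, this commit changes what each value of the Trainer's `early_stop_callback` argument means. A hedged sketch of the resulting behavior (not part of the commit itself; based only on the hunks below):

    from pytorch_lightning import Trainer

    # None (the new default): a lenient EarlyStopping on 'val_loss' is created;
    # if 'val_loss' is never logged, it quietly acts as if early stopping is off.
    trainer = Trainer(early_stop_callback=None)

    # True: a strict EarlyStopping on 'val_loss' is created; a missing
    # 'val_loss' now raises a RuntimeError, as early as the validation sanity check.
    trainer = Trainer(early_stop_callback=True)

    # False: early stopping is disabled entirely.
    trainer = Trainer(early_stop_callback=False)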
This commit is contained in:
parent 588ad83771
commit 50881c0b31
@@ -71,21 +71,23 @@ class EarlyStopping(Callback):
     Stop training when a monitored quantity has stopped improving.
 
     Args:
-        monitor (str): quantity to be monitored.
+        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
         min_delta (float): minimum change in the monitored quantity
             to qualify as an improvement, i.e. an absolute
-            change of less than min_delta, will count as no
-            improvement.
+            change of less than `min_delta`, will count as no
+            improvement. Default: ``0``.
         patience (int): number of epochs with no improvement
-            after which training will be stopped.
-        verbose (bool): verbosity mode.
+            after which training will be stopped. Default: ``0``.
+        verbose (bool): verbosity mode. Default: ``0``.
         mode (str): one of {auto, min, max}. In `min` mode,
             training will stop when the quantity
             monitored has stopped decreasing; in `max`
             mode it will stop when the quantity
             monitored has stopped increasing; in `auto`
             mode, the direction is automatically inferred
-            from the name of the monitored quantity.
+            from the name of the monitored quantity. Default: ``'auto'``.
+        strict (bool): whether to crash the training if `monitor` is
+            not found in the metrics. Default: ``True``.
 
     Example::
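A minimal usage sketch of the documented arguments, including the new `strict` flag (the metric name and values are illustrative, not from the commit):

    from pytorch_lightning.callbacks import EarlyStopping

    early_stopping = EarlyStopping(
        monitor='val_acc',   # assumes the model logs this metric
        min_delta=0.01,      # changes smaller than 0.01 count as no improvement
        patience=5,          # stop after 5 epochs without improvement
        mode='max',          # accuracy should increase
        strict=False,        # warn instead of crashing if 'val_acc' is missing
        verbose=1,
    )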
@@ -97,18 +99,20 @@ class EarlyStopping(Callback):
     """
 
     def __init__(self, monitor='val_loss',
-                 min_delta=0.0, patience=0, verbose=0, mode='auto'):
+                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
         super(EarlyStopping, self).__init__()
 
         self.monitor = monitor
         self.patience = patience
         self.verbose = verbose
+        self.strict = strict
         self.min_delta = min_delta
         self.wait = 0
         self.stopped_epoch = 0
 
         if mode not in ['auto', 'min', 'max']:
-            logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
+            if self.verbose > 0:
+                logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
             mode = 'auto'
 
         if mode == 'min':
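As the hunk above shows, the unknown-mode fallback message is now gated on verbosity. A hedged sketch of the difference:

    from pytorch_lightning.callbacks import EarlyStopping

    # verbose=0 (default): silently falls back to 'auto'
    es_quiet = EarlyStopping(mode='bananas')

    # verbose > 0: logs the fallback message first, then uses 'auto'
    es_loud = EarlyStopping(mode='bananas', verbose=1)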
@@ -128,6 +132,22 @@ class EarlyStopping(Callback):
 
         self.on_train_begin()
 
+    def check_metrics(self, logs):
+        monitor_val = logs.get(self.monitor)
+        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
+                     f' which is not available. Available metrics are:'
+                     f' `{"`, `".join(list(logs.keys()))}`')
+
+        if monitor_val is None:
+            if self.strict:
+                raise RuntimeError(error_msg)
+            elif self.verbose > 0:
+                warnings.warn(error_msg, RuntimeWarning)
+
+            return False
+
+        return True
+
     def on_train_begin(self, logs=None):
         # Allow instances to be re-used
         self.wait = 0
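The new `check_metrics` helper is what both the trainer and `on_epoch_end` call. A sketch of its contract, with plain dicts standing in for the trainer's callback metrics:

    from pytorch_lightning.callbacks import EarlyStopping

    es_strict = EarlyStopping(monitor='val_loss', strict=True)
    es_strict.check_metrics({'val_loss': 0.42})  # metric present -> returns True
    es_strict.check_metrics({'loss': 0.42})      # missing + strict -> raises RuntimeError

    es_lenient = EarlyStopping(monitor='val_loss', strict=False, verbose=1)
    es_lenient.check_metrics({'loss': 0.42})     # missing + lenient -> warns, returns False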
@@ -135,16 +155,11 @@ class EarlyStopping(Callback):
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
     def on_epoch_end(self, epoch, logs=None):
-        current = logs.get(self.monitor)
         stop_training = False
-        if current is None:
-            warnings.warn(
-                f'Early stopping conditioned on metric `{self.monitor}`'
-                f' which is not available. Available metrics are: {",".join(list(logs.keys()))}',
-                RuntimeWarning)
-            stop_training = True
+        if not self.check_metrics(logs):
+            return stop_training
 
+        current = logs.get(self.monitor)
         if self.monitor_op(current - self.min_delta, self.best):
             self.best = current
             self.wait = 0
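With the helper in place, `on_epoch_end` no longer duplicates the warning logic, and a missing metric now returns `stop_training = False` instead of requesting a stop. A hedged sketch (values illustrative):

    from pytorch_lightning.callbacks import EarlyStopping

    es = EarlyStopping(monitor='val_loss', patience=0, strict=False)
    es.on_epoch_end(epoch=0, logs={'val_loss': 0.5})  # improvement -> returns False
    es.on_epoch_end(epoch=1, logs={'loss': 0.5})      # metric missing -> returns False (keep training)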
@@ -55,10 +55,20 @@ class TrainerCallbackConfigMixin(ABC):
             self.early_stop_callback = EarlyStopping(
                 monitor='val_loss',
                 patience=3,
+                strict=True,
                 verbose=True,
                 mode='min'
             )
             self.enable_early_stop = True
+        elif early_stop_callback is None:
+            self.early_stop_callback = EarlyStopping(
+                monitor='val_loss',
+                patience=3,
+                strict=False,
+                verbose=False,
+                mode='min'
+            )
+            self.enable_early_stop = True
         elif not early_stop_callback:
             self.early_stop_callback = None
             self.enable_early_stop = False
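So passing `early_stop_callback=None` is now equivalent to supplying the lenient callback by hand, while `True` keeps the strict behavior. A sketch of the explicit form:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    default_for_none = EarlyStopping(monitor='val_loss', patience=3,
                                     strict=False, verbose=False, mode='min')
    trainer = Trainer(early_stop_callback=default_for_none)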
@@ -52,7 +52,7 @@ class Trainer(TrainerIOMixin,
         self,
         logger=True,
         checkpoint_callback=True,
-        early_stop_callback=True,
+        early_stop_callback=None,
         default_save_path=None,
         gradient_clip_val=0,
         gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -121,7 +121,13 @@ class Trainer(TrainerIOMixin,
             )
 
             trainer = Trainer(checkpoint_callback=checkpoint_callback)
-        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping
+        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping. If
+            set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
+            Will raise an error if ``'val_loss'`` is not found.
+            If set to ``False``, then early stopping will be disabled.
+            If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
+            If ``'val_loss'`` is not found, it will work as if early stopping is disabled.
+            Default: ``None``.
             Example::
 
                 from pytorch_lightning.callbacks import EarlyStopping
@@ -129,7 +135,8 @@ class Trainer(TrainerIOMixin,
                 early_stop_callback = EarlyStopping(
                     monitor='val_loss',
                     patience=3,
-                    verbose=True,
+                    strict=False,
+                    verbose=False,
                     mode='min'
                 )
@@ -809,12 +816,17 @@ class Trainer(TrainerIOMixin,
             # dummy validation progress bar
             self.val_progress_bar = tqdm.tqdm(disable=True)
 
-            self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
+            eval_results = self.evaluate(model, self.get_val_dataloaders(),
+                                         self.num_sanity_val_steps, False)
+            _, _, _, callback_metrics, _ = self.process_output(eval_results)
 
             # close progress bars
             self.main_progress_bar.close()
             self.val_progress_bar.close()
 
+            if self.enable_early_stop:
+                self.early_stop_callback.check_metrics(callback_metrics)
+
         # init progress bar
         pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                          disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
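The effect of this hunk is the fail-fast behavior named in the commit title: the sanity-check validation results are routed through `check_metrics`, so a strict callback errors before any training epoch runs. A hedged sketch of what a user would see (`LitModel` is a placeholder module that never logs `'val_loss'`):

    trainer = Trainer(early_stop_callback=True)
    trainer.fit(LitModel())
    # -> RuntimeError: Early stopping conditioned on metric `val_loss`
    #    which is not available. ...  (raised during the sanity check)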
@@ -346,7 +346,8 @@ class TrainerTrainLoopMixin(ABC):
 
             # early stopping
             met_min_epochs = epoch >= self.min_epochs - 1
-            if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
+            if (self.enable_early_stop and not self.disable_validation and
+                    (met_min_epochs or self.fast_dev_run)):
                 should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
                                                                     logs=self.callback_metrics)
                 # stop training
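This guard encodes the new rule: early stopping runs only when it is enabled, validation is active, and the minimum epoch count (or `fast_dev_run`) is met. A standalone restatement with hypothetical locals, not the trainer's actual API:

    def should_run_early_stopping(enable_early_stop: bool,
                                  disable_validation: bool,
                                  met_min_epochs: bool,
                                  fast_dev_run: bool) -> bool:
        # Skip early stopping whenever validation is disabled, since the
        # monitored validation metric would never be produced.
        return (enable_early_stop and not disable_validation
                and (met_min_epochs or fast_dev_run))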
@@ -401,6 +402,9 @@ class TrainerTrainLoopMixin(ABC):
             if self.fast_dev_run or should_check_val:
                 self.run_evaluation(test=self.testing)
 
+                if self.enable_early_stop:
+                    self.early_stop_callback.check_metrics(self.callback_metrics)
+
             # when logs should be saved
             should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
             if should_save_log or self.fast_dev_run:
@@ -140,7 +140,8 @@ def test_running_test_without_val(tmpdir):
         val_percent_check=0.2,
         test_percent_check=0.2,
         checkpoint_callback=checkpoint,
-        logger=logger
+        logger=logger,
+        early_stop_callback=False
     )
 
     # fit model
@@ -318,6 +319,7 @@ def test_tbptt_cpu_model(tmpdir):
         truncated_bptt_steps=truncated_bptt_steps,
         val_percent_check=0,
         weights_summary=None,
+        early_stop_callback=False
     )
 
     hparams = tutils.get_hparams()
@@ -392,7 +392,7 @@ def test_multiple_test_dataloader(tmpdir):
         default_save_path=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
-        train_percent_check=0.2,
+        train_percent_check=0.2
    )
 
     # fit model