early stopping checks on_validation_end (#1458)
* Fixes PyTorchLightning/pytorch-lightning#490

  `EarlyStopping` should check the metric of interest `on_validation_end` rather than `on_epoch_end`. In a normal scenario this does not cause a problem, but in combination with `check_val_every_n_epoch > 1` in the `Trainer` it results in a warning or a `RuntimeError`, depending on `strict`.

* Highlighted that ES callback runs on val epochs in docstring
* Updated EarlyStopping in rst doc
* Update early_stopping.py
* Update early_stopping.rst
* Apply suggestions from code review

  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update docs/source/early_stopping.rst
* Fix doctest indentation warning
* Train loop calls early_stop.on_validation_end
* chlog

Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
This commit is contained in:
parent
8ca8336ce5
commit
65b4352930
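For context, a minimal sketch of a configuration that exhibited the bug, assuming the 0.7-era API that appears elsewhere in this diff (`early_stop_callback`, `check_val_every_n_epoch`); the exact values are illustrative:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Validation runs only every second epoch. On the epochs in between, no
    # fresh 'val_loss' exists when the old on_epoch_end check fires:
    # strict=True raised a RuntimeError, strict=False merely warned.
    early_stop = EarlyStopping(monitor='val_loss', strict=True)
    trainer = Trainer(check_val_every_n_epoch=2, early_stop_callback=early_stop)

After this commit the check runs in `on_validation_end`, so it only fires when validation has actually produced the monitored metric.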
CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)
 
+- Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))
+
 ### Changed
 
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
docs/source/early_stopping.rst
@@ -19,36 +19,63 @@ By default early stopping will be enabled if `'val_loss'`
 is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s
 return dict. Otherwise training will proceed with early stopping disabled.
 
-Enable Early Stopping using Callbacks on epoch end
---------------------------------------------------
-There are two ways to enable early stopping using callbacks on epoch end.
+Enable Early Stopping using the EarlyStopping Callback
+------------------------------------------------------
+The
+:class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
+callback can be used to monitor a validation metric and stop the training when no improvement is observed.
 
-- Set early_stop_callback to True. Will look for 'val_loss' in validation_epoch_end() return dict.
-  If it is not found an error is raised.
+There are two ways to enable the EarlyStopping callback:
+
+- Set `early_stop_callback=True`.
+  The callback will look for 'val_loss' in the dict returned by
+  :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
+  and raise an error if `val_loss` is not present.
 
   .. testcode::
 
     trainer = Trainer(early_stop_callback=True)
 
-- Or configure your own callback
+- Create the callback object and pass it to the trainer.
+  This allows for further customization.
 
   .. testcode::
 
    early_stop_callback = EarlyStopping(
-       monitor='val_loss',
-       min_delta=0.00,
-       patience=3,
-       verbose=False,
-       mode='min'
+       monitor='val_accuracy',
+       min_delta=0.00,
+       patience=3,
+       verbose=False,
+       mode='max'
    )
    trainer = Trainer(early_stop_callback=early_stop_callback)
 
+In any case, the callback will fall back to the training metrics (returned in
+:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
+:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`)
+looking for a key to monitor if validation is disabled or
+:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
+is not defined.
 In case you need early stopping in a different part of training, subclass EarlyStopping
 and change where it is called:
 
 .. testcode::
 
    class MyEarlyStopping(EarlyStopping):
 
       def on_validation_end(self, trainer, pl_module):
          # override this to disable early stopping at the end of val loop
          pass
 
       def on_train_end(self, trainer, pl_module):
          # instead, do it at the end of training loop
          self._run_early_stopping_check(trainer, pl_module)
 
+.. note::
+   The EarlyStopping callback runs at the end of every validation epoch,
+   which, under the default configuration, happens after every training epoch.
+   However, the frequency of validation can be modified by setting various parameters
+   on the :class:`~pytorch_lightning.trainer.trainer.Trainer`,
+   for example :paramref:`~pytorch_lightning.trainer.trainer.Trainer.check_val_every_n_epoch`
+   and :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval`.
+   Note that the `patience` parameter counts the number of validation epochs
+   with no improvement, and not the number of training epochs.
+   Therefore, with `check_val_every_n_epoch=10` and `patience=3`, the trainer
+   will perform at least 40 training epochs before being stopped.
 
 .. seealso::
     - :class:`~pytorch_lightning.trainer.trainer.Trainer`
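To make the note's arithmetic concrete, a small sketch using the same API as the testcode blocks above (the monitored key and values are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # `patience` counts validation epochs, not training epochs. Validation runs
    # every 10 training epochs; the first check can only establish the baseline,
    # and three further checks without improvement are then required:
    # 4 checks * 10 epochs = at least 40 training epochs before stopping.
    early_stop = EarlyStopping(monitor='val_loss', patience=3, mode='min')
    trainer = Trainer(check_val_every_n_epoch=10, early_stop_callback=early_stop)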
pytorch_lightning/callbacks/early_stopping.py
@@ -2,7 +2,7 @@ r"""
 Early Stopping
 ==============
 
-Stop training when a monitored quantity has stopped improving.
+Monitor a validation metric and stop training when it stops improving.
 
 """
@@ -25,7 +25,7 @@ class EarlyStopping(Callback):
             to qualify as an improvement, i.e. an absolute
             change of less than `min_delta`, will count as no
             improvement. Default: ``0``.
-        patience: number of epochs with no improvement
+        patience: number of validation epochs with no improvement
             after which training will be stopped. Default: ``0``.
         verbose: verbosity mode. Default: ``False``.
         mode: one of {auto, min, max}. In `min` mode,
@@ -36,7 +36,7 @@ class EarlyStopping(Callback):
             mode, the direction is automatically inferred
             from the name of the monitored quantity. Default: ``'auto'``.
         strict: whether to crash the training if `monitor` is
-            not found in the metrics. Default: ``True``.
+            not found in the validation metrics. Default: ``True``.
 
     Example::
@@ -109,7 +109,10 @@ class EarlyStopping(Callback):
         self.stopped_epoch = 0
         self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
 
-    def on_epoch_end(self, trainer, pl_module):
+    def on_validation_end(self, trainer, pl_module):
+        self._run_early_stopping_check(trainer, pl_module)
+
+    def _run_early_stopping_check(self, trainer, pl_module):
         logs = trainer.callback_metrics
         stop_training = False
         if not self._validate_condition_metric(logs):
pytorch_lightning/trainer/training_loop.py
@@ -362,7 +362,7 @@ class TrainerTrainLoopMixin(ABC):
             # TODO wrap this logic into the callback
             if self.enable_early_stop:
                 if (met_min_epochs and met_min_steps) or self.fast_dev_run:
-                    should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
+                    should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())
                     # stop training
                     stop = should_stop and met_min_epochs
                     if stop:
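The train-loop call above is what keeps early stopping functional when no validation loop runs: the loop invokes the callback's `on_validation_end` directly, and the callback falls back to the training metrics, as the docs change describes. A hypothetical sketch; the `val_percent_check` argument and the `loss` key are assumptions, not shown in this diff:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # With validation disabled (assumed val_percent_check=0), the train loop
    # itself calls early_stop_callback.on_validation_end(...), so the callback
    # can monitor a key returned from training_step (assumed 'loss').
    early_stop = EarlyStopping(monitor='loss', mode='min', patience=5)
    trainer = Trainer(val_percent_check=0, early_stop_callback=early_stop)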