early stopping checks on_validation_end (#1458)

* Fixes PyTorchLightning/pytorch-lightning#490

`EarlyStopping` should check the metric of interest in `on_validation_end` rather than in `on_epoch_end`.
In the default configuration this makes no difference, but in combination with `check_val_every_n_epoch > 1` in the `Trainer`, checking in `on_epoch_end` results in a warning or a `RuntimeError`, depending on `strict`.
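For context, a minimal sketch of the configuration that used to trigger the failure (0.7-era API, as used in this PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Validation runs only every 2nd epoch, so 'val_loss' is missing from
# trainer.callback_metrics on the epochs in between. Before this fix the
# check ran in on_epoch_end and therefore also fired on those epochs,
# raising a RuntimeError with strict=True (or a warning with strict=False).
early_stopping = EarlyStopping(monitor='val_loss', strict=True)
trainer = Trainer(
    check_val_every_n_epoch=2,
    early_stop_callback=early_stopping,
)
```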

* Highlighted that ES callback runs on val epochs in docstring

* Updated EarlyStopping in rst doc

* Update early_stopping.py

* Update early_stopping.rst

* Update early_stopping.rst

* Update early_stopping.rst

* Update early_stopping.rst

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update docs/source/early_stopping.rst

* fix doctest indentation warning

* Train loop calls early_stop.on_validation_end

* chlog

Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Federico Baldassarre 2020-05-25 19:33:00 +02:00 committed by GitHub
parent 8ca8336ce5
commit 65b4352930
4 changed files with 54 additions and 22 deletions

CHANGELOG.md

@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` ([#1908](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908))
+- Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))
+
 ### Changed
 
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))

docs/source/early_stopping.rst

@@ -19,36 +19,63 @@ By default early stopping will be enabled if `'val_loss'`
 is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s
 return dict. Otherwise training will proceed with early stopping disabled.
 
-Enable Early Stopping using Callbacks on epoch end
---------------------------------------------------
+Enable Early Stopping using the EarlyStopping Callback
+------------------------------------------------------
 
-There are two ways to enable early stopping using callbacks on epoch end.
+The
+:class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
+callback can be used to monitor a validation metric and stop the training when no improvement is observed.
+
+There are two ways to enable the EarlyStopping callback:
 
-- Set early_stop_callback to True. Will look for 'val_loss' in validation_epoch_end() return dict.
-  If it is not found an error is raised.
+- Set `early_stop_callback=True`.
+  The callback will look for 'val_loss' in the dict returned by
+  :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
+  and raise an error if `val_loss` is not present.
 
   .. testcode::
 
      trainer = Trainer(early_stop_callback=True)
 
-- Or configure your own callback
+- Create the callback object and pass it to the trainer.
+  This allows for further customization.
 
   .. testcode::
 
      early_stop_callback = EarlyStopping(
-        monitor='val_loss',
-        min_delta=0.00,
-        patience=3,
-        verbose=False,
-        mode='min'
+         monitor='val_accuracy',
+         min_delta=0.00,
+         patience=3,
+         verbose=False,
+         mode='max'
      )
      trainer = Trainer(early_stop_callback=early_stop_callback)
 
+In any case, the callback will fall back to the training metrics (returned in
+:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
+:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`)
+looking for a key to monitor if validation is disabled or
+:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
+is not defined.
+
+In case you need early stopping in a different part of training, subclass EarlyStopping
+and change where it is called:
+
+.. testcode::
+
+    class MyEarlyStopping(EarlyStopping):
+
+        def on_validation_end(self, trainer, pl_module):
+            # override this to disable early stopping at the end of val loop
+            pass
+
+        def on_train_end(self, trainer, pl_module):
+            # instead, do it at the end of training loop
+            self._run_early_stopping_check(trainer, pl_module)
+
+.. note::
+    The EarlyStopping callback runs at the end of every validation epoch,
+    which, under the default configuration, happens after every training epoch.
+    However, the frequency of validation can be modified by setting various parameters
+    on the :class:`~pytorch_lightning.trainer.trainer.Trainer`,
+    for example :paramref:`~pytorch_lightning.trainer.trainer.Trainer.check_val_every_n_epoch`
+    and :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval`.
+    Note that the `patience` parameter counts the number of
+    validation epochs with no improvement, and not the number of training epochs.
+    Therefore, with parameters `check_val_every_n_epoch=10` and `patience=3`, the trainer
+    will perform at least 40 training epochs before being stopped.
 
 .. seealso::
     - :class:`~pytorch_lightning.trainer.trainer.Trainer`
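To make the interaction between `patience` and validation frequency concrete, a small illustrative helper (hypothetical, not part of the library):

```python
def min_epochs_before_stop(check_val_every_n_epoch: int, patience: int) -> int:
    """Earliest training epoch at which EarlyStopping can trigger.

    The callback sees one "validation epoch" every `check_val_every_n_epoch`
    training epochs, and `patience` counts validation epochs with no
    improvement after the best result was recorded.
    """
    # 1 validation run that sets the best value, then `patience` runs
    # without improvement before stopping
    return check_val_every_n_epoch * (1 + patience)


# the scenario from the note above: stop no earlier than epoch 40
assert min_epochs_before_stop(check_val_every_n_epoch=10, patience=3) == 40
```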

pytorch_lightning/callbacks/early_stopping.py

@ -2,7 +2,7 @@ r"""
Early Stopping
==============
Stop training when a monitored quantity has stopped improving.
Monitor a validation metric and stop training when it stops improving.
"""
@ -25,7 +25,7 @@ class EarlyStopping(Callback):
to qualify as an improvement, i.e. an absolute
change of less than `min_delta`, will count as no
improvement. Default: ``0``.
patience: number of epochs with no improvement
patience: number of validation epochs with no improvement
after which training will be stopped. Default: ``0``.
verbose: verbosity mode. Default: ``False``.
mode: one of {auto, min, max}. In `min` mode,
@ -36,7 +36,7 @@ class EarlyStopping(Callback):
mode, the direction is automatically inferred
from the name of the monitored quantity. Default: ``'auto'``.
strict: whether to crash the training if `monitor` is
not found in the metrics. Default: ``True``.
not found in the validation metrics. Default: ``True``.
Example::
@ -109,7 +109,10 @@ class EarlyStopping(Callback):
self.stopped_epoch = 0
self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
def on_epoch_end(self, trainer, pl_module):
def on_validation_end(self, trainer, pl_module):
self._run_early_stopping_check(trainer, pl_module)
def _run_early_stopping_check(self, trainer, pl_module):
logs = trainer.callback_metrics
stop_training = False
if not self._validate_condition_metric(logs):
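For orientation, a hedged sketch (0.7-era API, details omitted) of how the monitored key typically ends up in `trainer.callback_metrics`, which `_run_early_stopping_check` reads:

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    # ... model definition, training_step, validation_step, etc. omitted ...

    def validation_epoch_end(self, outputs):
        # average the per-batch losses collected from validation_step
        val_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        # keys returned here are merged into trainer.callback_metrics,
        # where EarlyStopping looks up its `monitor` key
        return {'val_loss': val_loss, 'log': {'val_loss': val_loss}}
```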

pytorch_lightning/trainer/training_loop.py

@@ -362,7 +362,7 @@ class TrainerTrainLoopMixin(ABC):
             # TODO wrap this logic into the callback
             if self.enable_early_stop:
                 if (met_min_epochs and met_min_steps) or self.fast_dev_run:
-                    should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
+                    should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())
                     # stop training
                     stop = should_stop and met_min_epochs
                     if stop: