diff --git a/CHANGELOG.md b/CHANGELOG.md index e8c71252c1..f87b746a14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908) +- Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458)) + ### Changed - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729)) diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst index a0bfc83ec2..7e949c0f7f 100644 --- a/docs/source/early_stopping.rst +++ b/docs/source/early_stopping.rst @@ -19,36 +19,63 @@ By default early stopping will be enabled if `'val_loss'` is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s return dict. Otherwise training will proceed with early stopping disabled. -Enable Early Stopping using Callbacks on epoch end --------------------------------------------------- -There are two ways to enable early stopping using callbacks on epoch end. +Enable Early Stopping using the EarlyStopping Callback +------------------------------------------------------ +The +:class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` +callback can be used to monitor a validation metric and stop the training when no improvement is observed. -- Set early_stop_callback to True. Will look for 'val_loss' in validation_epoch_end() return dict. - If it is not found an error is raised. +There are two ways to enable the EarlyStopping callback: + +- Set `early_stop_callback=True`. + The callback will look for 'val_loss' in the dict returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` + and raise an error if `val_loss` is not present. .. 
testcode:: trainer = Trainer(early_stop_callback=True) -- Or configure your own callback +- Create the callback object and pass it to the trainer. + This allows for further customization. .. testcode:: early_stop_callback = EarlyStopping( - monitor='val_loss', - min_delta=0.00, - patience=3, - verbose=False, - mode='min' + monitor='val_accuracy', + min_delta=0.00, + patience=3, + verbose=False, + mode='max' ) trainer = Trainer(early_stop_callback=early_stop_callback) -In any case, the callback will fall back to the training metrics (returned in -:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`, -:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`) -looking for a key to monitor if validation is disabled or -:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` -is not defined. +In case you need early stopping in a different part of training, subclass EarlyStopping +and change where it is called: + +.. testcode:: + + class MyEarlyStopping(EarlyStopping): + + def on_validation_end(self, trainer, pl_module): + # override this to disable early stopping at the end of val loop + pass + + def on_train_end(self, trainer, pl_module): + # instead, do it at the end of training loop + self._run_early_stopping_check(trainer, pl_module) + +.. note:: + The EarlyStopping callback runs at the end of every validation epoch, + which, under the default configuration, happens after every training epoch. + However, the frequency of validation can be modified by setting various parameters + on the :class:`~pytorch_lightning.trainer.trainer.Trainer`, + for example :paramref:`~pytorch_lightning.trainer.trainer.Trainer.check_val_every_n_epoch` + and :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval`. + It must be noted that the `patience` parameter counts the number of + validation epochs with no improvement, and not the number of training epochs.
+ Therefore, with parameters `check_val_every_n_epoch=10` and `patience=3`, the trainer + will perform at least 40 training epochs before being stopped. .. seealso:: - :class:`~pytorch_lightning.trainer.trainer.Trainer` diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 100c317172..61abfb879a 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -2,7 +2,7 @@ r""" Early Stopping ============== -Stop training when a monitored quantity has stopped improving. +Monitor a validation metric and stop training when it stops improving. """ @@ -25,7 +25,7 @@ class EarlyStopping(Callback): to qualify as an improvement, i.e. an absolute change of less than `min_delta`, will count as no improvement. Default: ``0``. - patience: number of epochs with no improvement + patience: number of validation epochs with no improvement after which training will be stopped. Default: ``0``. verbose: verbosity mode. Default: ``False``. mode: one of {auto, min, max}. In `min` mode, @@ -36,7 +36,7 @@ class EarlyStopping(Callback): mode, the direction is automatically inferred from the name of the monitored quantity. Default: ``'auto'``. strict: whether to crash the training if `monitor` is - not found in the metrics. Default: ``True``. + not found in the validation metrics. Default: ``True``. 
Example:: @@ -109,7 +109,10 @@ class EarlyStopping(Callback): self.stopped_epoch = 0 self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf - def on_epoch_end(self, trainer, pl_module): + def on_validation_end(self, trainer, pl_module): + self._run_early_stopping_check(trainer, pl_module) + + def _run_early_stopping_check(self, trainer, pl_module): logs = trainer.callback_metrics stop_training = False if not self._validate_condition_metric(logs): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 61358731f7..55c63679ae 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -362,7 +362,7 @@ class TrainerTrainLoopMixin(ABC): # TODO wrap this logic into the callback if self.enable_early_stop: if (met_min_epochs and met_min_steps) or self.fast_dev_run: - should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model()) + should_stop = self.early_stop_callback.on_validation_end(self, self.get_model()) # stop training stop = should_stop and met_min_epochs if stop: