diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst
index e94cb079a8..e74a720b30 100644
--- a/docs/source/early_stopping.rst
+++ b/docs/source/early_stopping.rst
@@ -1,15 +1,21 @@
 Early stopping
 ==============

-Default behavior
-----------------
+Stopping an epoch early
+-----------------------
+You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.lightning.LightningModule.on_batch_start` to return ``-1`` when some condition is met.
+
+If you do this for every epoch you had originally requested, the entire training run will stop.
+
+Default Epoch End Callback Behavior
+-----------------------------------
 By default early stopping will be enabled if `'val_loss'`
 is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s
 return dict. Otherwise training will proceed with early stopping disabled.

-Enable Early Stopping
----------------------
-There are two ways to enable early stopping.
+Enable Early Stopping using Callbacks on epoch end
+--------------------------------------------------
+There are two ways to enable early stopping using callbacks on epoch end.

 .. doctest::

@@ -39,8 +45,8 @@ is not defined.
 .. seealso::
     :class:`~pytorch_lightning.trainer.trainer.Trainer`

-Disable Early Stopping
-----------------------
+Disable Early Stopping with callbacks on epoch end
+--------------------------------------------------
 To disable early stopping pass ``False`` to the
 :paramref:`~pytorch_lightning.trainer.trainer.Trainer.early_stop_callback`.
 Note that ``None`` will not disable early stopping but will lead to the
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index bfac96d205..9c4b999d14 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -51,6 +51,8 @@ class ModelHooks(torch.nn.Module):
     def on_batch_start(self, batch: Any) -> None:
         """Called in the training loop before anything happens for that batch.

+        If you return -1 here, you will skip training for the rest of the current epoch.
+
         :param batch:
         """
         # do something when the batch starts
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index f0413d4de8..ee0b2463ac 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -221,9 +221,6 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
                     "hiddens": hiddens # remember to detach() this
                 }

-            You can also return a -1 instead of a dict to stop the current loop. This is useful
-            if you want to break out of the current training epoch early.
-
             Notes:
                 The loss value shown in the progress bar is smoothed (averaged) over the last values,
                 so it differs from the actual loss returned in train/validation step.
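As a usage illustration of the hook behavior documented above: a minimal sketch of a ``LightningModule`` that skips the rest of an epoch once its training loss drops below a cutoff. The ``ThresholdModel`` name, the ``loss_threshold`` value, and the ``_last_loss`` bookkeeping are hypothetical, not part of this diff; only the return-``-1`` contract for ``on_batch_start`` comes from the docstring added here.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class ThresholdModel(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.loss_threshold = 0.01  # assumed cutoff; pick one for your task
            self._last_loss = None  # bookkeeping so on_batch_start can see the loss

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self.layer(x), y)
            self._last_loss = loss.item()
            return {'loss': loss}

        def on_batch_start(self, batch):
            # Returning -1 skips training for the rest of the current epoch,
            # per the hooks.py docstring added in this diff.
            if self._last_loss is not None and self._last_loss < self.loss_threshold:
                return -1

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.02)

        def train_dataloader(self):
            x = torch.randn(64, 32)
            y = torch.randint(0, 2, (64,))
            return DataLoader(TensorDataset(x, y), batch_size=8)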
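For the callback-on-epoch-end sections of the docs page, here is a sketch of enabling and disabling early stopping through the ``early_stop_callback`` Trainer argument that this diff cross-references. The ``EarlyStopping`` argument values shown are illustrative choices, not values mandated by the diff.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Explicit callback: stop once 'val_loss' has not improved for 3 epochs.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.0,
        patience=3,
        verbose=False,
        mode='min',
    )
    trainer = Trainer(early_stop_callback=early_stop_callback)

    # Default-style enabling: monitors 'val_loss' when validation_epoch_end returns it.
    trainer = Trainer(early_stop_callback=True)

    # Disabling: pass False explicitly; per the docs above, None does NOT disable it.
    trainer = Trainer(early_stop_callback=False)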