diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst index 643a2d5160..6e88728a3e 100644 --- a/docs/source/early_stopping.rst +++ b/docs/source/early_stopping.rst @@ -33,9 +33,17 @@ callback can be used to monitor a validation metric and stop the training when n There are two ways to enable the EarlyStopping callback: - Set `early_stop_callback=True`. - The callback will look for 'val_loss' in the dict returned by - :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` - and raise an error if `val_loss` is not present. + If a dict is returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, + the callback will look for `val_loss` in the dict + and display a warning if `val_loss` is not present. + Otherwise, if a :class:`~pytorch_lightning.core.step_result.Result` is returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, + :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` or + :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`, + the `early_stop_on` metric, specified in the initialization of the + :class:`~pytorch_lightning.core.step_result.Result` object is used + and display a warning if it was not specified. .. testcode:: diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index a9a96abaf6..16250f62c3 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -501,7 +501,16 @@ class TrainResult(Result): Args: minimize: Metric currently being minimized. early_stop_on: Metric to early stop on. + Should be a one element tensor if combined with default + :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`. + If this result is returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`, + the specified value will be averaged across all steps. checkpoint_on: Metric to checkpoint on. + Should be a one element tensor if combined with default checkpoint callback. + If this result is returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`, + the specified value will be averaged across all steps. hiddens: """ @@ -656,7 +665,16 @@ class EvalResult(Result): Args: early_stop_on: Metric to early stop on. + Should be a one element tensor if combined with default + :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`. + If this result is returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`, + the specified value will be averaged across all steps. checkpoint_on: Metric to checkpoint on. + Should be a one element tensor if combined with default checkpoint callback. + If this result is returned by + :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`, + the specified value will be averaged across all steps. hiddens: """ diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 011b1b27cb..dd30d18281 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -411,10 +411,13 @@ early_stop_callback Callback for early stopping. early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`) -- ``True``: A default callback monitoring ``'val_loss'`` is created. - Will raise an error if ``'val_loss'`` is not found. +- ``True``: A default callback monitoring ``'val_loss'`` (if dict is returned in validation loop) or + ``early_stopping_on`` (if :class:`~pytorch_lightning.core.step_result.Result` is returned) is created. + Will raise an error if a dictionary is returned and ``'val_loss'`` is not found. + Will raise an error if a :class:`~pytorch_lightning.core.step_result.Result` is returned + and ``early_stopping_on`` was not specified. - ``False``: Early stopping will be disabled. -- ``None``: The default callback monitoring ``'val_loss'`` is created. +- ``None``: Same as, if ``True`` is specified. - Default: ``None``. .. testcode::