Early stopping when validation is disabled (#1235)

* early stop fallback to train epoch

* added test

* fix imports

* update docs

* update changelog

* fix typo
Adrian Wälchli, 2020-03-31 08:24:26 +02:00 (committed by GitHub)
parent a707d4bea1
commit 1aba411da9
4 changed files with 48 additions and 7 deletions

CHANGELOG.md

@@ -43,6 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
 - Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
 - Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
+- Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)).
 
 ## [0.7.1] - 2020-03-07

docs/source/early_stopping.rst

@@ -4,8 +4,8 @@ Early stopping
 Default behavior
 ----------------
 By default early stopping will be enabled if `'val_loss'`
-is found in `validation_epoch_end()` return dict. Otherwise
-training will proceed with early stopping disabled.
+is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s
+return dict. Otherwise training will proceed with early stopping disabled.
 
 Enable Early Stopping
 ---------------------
@@ -30,9 +30,18 @@ There are two ways to enable early stopping.
     )
     trainer = Trainer(early_stop_callback=early_stop_callback)
 
+In any case, the callback will fall back to the training metrics (returned in
+:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
+:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`)
+looking for a key to monitor if validation is disabled or
+:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
+is not defined.
+
 Disable Early Stopping
 ----------------------
-To disable early stopping pass ``False`` to the `early_stop_callback`.
+To disable early stopping pass ``False`` to the
+:paramref:`~pytorch_lightning.trainer.trainer.Trainer.early_stop_callback`.
 Note that ``None`` will not disable early stopping but will lead to the
 default behaviour.
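
The documented fallback can be made concrete with a short sketch against the 0.7.x-era API. The model class and its compute_loss helper below are hypothetical illustrations; the monitor key and min_delta mirror the test added in this commit.

# Minimal sketch of the fallback (0.7.x-era API assumed): the model defines
# no validation hooks, so EarlyStopping watches a metric from training_step.
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping


class TrainOnlyModel(LightningModule):  # hypothetical example model

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical loss helper
        # expose the value the callback should monitor
        return {'loss': loss, 'my_train_metric': loss}

    # intentionally no validation_step / validation_epoch_end


early_stop = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
trainer = Trainer(early_stop_callback=early_stop)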

pytorch_lightning/trainer/training_loop.py

@@ -367,8 +367,8 @@ class TrainerTrainLoopMixin(ABC):
         met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
 
         # TODO wrap this logic into the callback
-        if self.enable_early_stop and not self.disable_validation and is_val_epoch:
-            if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
+        if self.enable_early_stop:
+            if (met_min_epochs and met_min_steps) or self.fast_dev_run:
                 should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
                 # stop training
                 stop = should_stop and met_min_epochs
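
Before this change, the `not self.disable_validation and is_val_epoch` guard skipped the callback entirely whenever validation was off; with the simplified condition it runs every epoch and can consume training metrics. A hedged sketch of the scenario this enables, assuming 0.7.x-era Trainer flags:

# Sketch (0.7.x-era flags assumed): validation is fully disabled, yet the
# early stopping callback still runs, falling back to training metrics.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

trainer = Trainer(
    early_stop_callback=EarlyStopping(monitor='loss', patience=3),
    val_percent_check=0,  # disables validation (see #1251 in the changelog)
    max_epochs=100,
)
# trainer.fit(model)  # model defines only training_step, as sketched above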

tests/trainer/test_callbacks.py

@@ -1,11 +1,12 @@
 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning.callbacks import EarlyStopping
 from tests.base import (
+    TestModelBase,
     LightTrainDataloader,
-    LightTestMixin,
     LightValidationMixin,
-    TestModelBase
+    LightTestMixin
 )
@@ -150,3 +151,33 @@ def test_trainer_callback_system(tmpdir):
     assert test_callback.on_test_start_called
     assert test_callback.on_test_end_called
+
+
+def test_early_stopping_without_val_step(tmpdir):
+    """Test that early stopping callback falls back to training metrics when no validation defined."""
+    tutils.reset_seed()
+
+    class ModelWithoutValStep(LightTrainDataloader, TestModelBase):
+
+        def training_step(self, *args, **kwargs):
+            output = super().training_step(*args, **kwargs)
+            loss = output['loss']  # could be anything else
+            output.update({'my_train_metric': loss})
+            return output
+
+    hparams = tutils.get_default_hparams()
+    model = ModelWithoutValStep(hparams)
+
+    stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        early_stop_callback=stopping,
+        overfit_pct=0.20,
+        max_epochs=10,
+    )
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    assert result == 1, 'training failed to complete'
+    assert trainer.current_epoch < trainer.max_epochs - 1