Early stopping when validation is disabled (#1235)
* early stop fallback to train epoch * added test * fix imports * update docs * update changelog * fix typo
This commit is contained in:
parent
a707d4bea1
commit
1aba411da9
|
@ -43,6 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
|
||||
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
|
||||
- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
|
||||
- Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)).
|
||||
|
||||
## [0.7.1] - 2020-03-07
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@ Early stopping
|
|||
Default behavior
|
||||
----------------
|
||||
By default early stopping will be enabled if `'val_loss'`
|
||||
is found in `validation_epoch_end()` return dict. Otherwise
|
||||
training will proceed with early stopping disabled.
|
||||
is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s
|
||||
return dict. Otherwise training will proceed with early stopping disabled.
|
||||
|
||||
Enable Early Stopping
|
||||
---------------------
|
||||
|
@ -30,9 +30,18 @@ There are two ways to enable early stopping.
|
|||
)
|
||||
trainer = Trainer(early_stop_callback=early_stop_callback)
|
||||
|
||||
In any case, the callback will fall back to the training metrics (returned in
|
||||
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
|
||||
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`)
|
||||
looking for a key to monitor if validation is disabled or
|
||||
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
|
||||
is not defined.
|
||||
|
||||
|
||||
Disable Early Stopping
|
||||
----------------------
|
||||
To disable early stopping pass ``False`` to the `early_stop_callback`.
|
||||
To disable early stopping pass ``False`` to the
|
||||
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.early_stop_callback`.
|
||||
Note that ``None`` will not disable early stopping but will lead to the
|
||||
default behaviour.
|
||||
|
||||
|
|
|
@ -367,8 +367,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
|
||||
|
||||
# TODO wrap this logic into the callback
|
||||
if self.enable_early_stop and not self.disable_validation and is_val_epoch:
|
||||
if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
|
||||
if self.enable_early_stop:
|
||||
if (met_min_epochs and met_min_steps) or self.fast_dev_run:
|
||||
should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
|
||||
# stop training
|
||||
stop = should_stop and met_min_epochs
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Callback
|
||||
from pytorch_lightning import Trainer, LightningModule
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from tests.base import (
|
||||
TestModelBase,
|
||||
LightTrainDataloader,
|
||||
LightTestMixin,
|
||||
LightValidationMixin,
|
||||
LightTestMixin
|
||||
TestModelBase
|
||||
)
|
||||
|
||||
|
||||
|
@ -150,3 +151,33 @@ def test_trainer_callback_system(tmpdir):
|
|||
|
||||
assert test_callback.on_test_start_called
|
||||
assert test_callback.on_test_end_called
|
||||
|
||||
|
||||
def test_early_stopping_without_val_step(tmpdir):
|
||||
"""Test that early stopping callback falls back to training metrics when no validation defined."""
|
||||
tutils.reset_seed()
|
||||
|
||||
class ModelWithoutValStep(LightTrainDataloader, TestModelBase):
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
output = super().training_step(*args, **kwargs)
|
||||
loss = output['loss'] # could be anything else
|
||||
output.update({'my_train_metric': loss})
|
||||
return output
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = ModelWithoutValStep(hparams)
|
||||
|
||||
stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
|
||||
trainer_options = dict(
|
||||
default_save_path=tmpdir,
|
||||
early_stop_callback=stopping,
|
||||
overfit_pct=0.20,
|
||||
max_epochs=10,
|
||||
)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1, 'training failed to complete'
|
||||
assert trainer.current_epoch < trainer.max_epochs - 1
|
||||
|
|
Loading…
Reference in New Issue