From 1aba411da96ed95419d13ec1f86a0d38a232f73e Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Tue, 31 Mar 2020 08:24:26 +0200
Subject: [PATCH] Early stopping when validation is disabled (#1235)

* early stop fallback to train epoch

* added test

* fix imports

* update docs

* update changelog

* fix typo
---
 CHANGELOG.md                               |  1 +
 docs/source/early_stopping.rst             | 15 ++++++++--
 pytorch_lightning/trainer/training_loop.py |  4 +--
 tests/trainer/test_callbacks.py            | 35 ++++++++++++++++++++--
 4 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fb7535fb5e..590095fa7f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -43,6 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
 - Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
 - Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
+- Fixed early stopping so it can monitor training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235))
 
 ## [0.7.1] - 2020-03-07
 
diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst
index 3410fc5c0a..585627a3b0 100644
--- a/docs/source/early_stopping.rst
+++ b/docs/source/early_stopping.rst
@@ -4,8 +4,8 @@ Early stopping
 Default behavior
 ----------------
 By default early stopping will be enabled if `'val_loss'`
-is found in `validation_epoch_end()` return dict. Otherwise
-training will proceed with early stopping disabled.
+is found in the dict returned by :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`.
+Otherwise training will proceed with early stopping disabled.
 
 Enable Early Stopping
 ---------------------
@@ -30,9 +30,18 @@ There are two ways to enable early stopping.
    )
    trainer = Trainer(early_stop_callback=early_stop_callback)
 
+In any case, when looking for a key to monitor, the callback will fall back
+to the training metrics (returned in
+:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` or
+:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`)
+if validation is disabled or
+:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` is not defined.
+
+
 Disable Early Stopping
 ----------------------
-To disable early stopping pass ``False`` to the `early_stop_callback`.
+To disable early stopping, pass ``False`` to the
+:paramref:`~pytorch_lightning.trainer.trainer.Trainer.early_stop_callback`.
 Note that ``None`` will not disable early stopping but will lead to the
 default behaviour.
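
To illustrate the fallback documented above, here is a minimal usage sketch. It is not part of this patch: the module, the `train_loss` key, and all hyperparameters are illustrative, written against the 0.7-era dict-returning `training_step` API, mirroring the pattern exercised by the new test below.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping

class TrainOnlyModel(LightningModule):
    # no validation_step / validation_epoch_end: early stopping must
    # read its monitored key from the training-step output dict
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        # expose the metric under the key that EarlyStopping monitors
        return {'loss': loss, 'train_loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8)

# monitor the training metric directly; no validation loop is required
trainer = Trainer(
    early_stop_callback=EarlyStopping(monitor='train_loss'),
    max_epochs=20,
)
trainer.fit(TrainOnlyModel())

With the change in `training_loop.py` (next file), such a run stops once `train_loss` plateaus; before this patch the early-stop check was skipped entirely whenever validation was disabled.
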
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d0a8ee1eb1..f01e2294c4 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -367,8 +367,8 @@ class TrainerTrainLoopMixin(ABC):
         met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
 
         # TODO wrap this logic into the callback
-        if self.enable_early_stop and not self.disable_validation and is_val_epoch:
-            if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
+        if self.enable_early_stop:
+            if (met_min_epochs and met_min_steps) or self.fast_dev_run:
                 should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
                 # stop training
                 stop = should_stop and met_min_epochs
diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py
index 377dce7666..5d0c396186 100644
--- a/tests/trainer/test_callbacks.py
+++ b/tests/trainer/test_callbacks.py
@@ -1,11 +1,12 @@
 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning.callbacks import EarlyStopping
 from tests.base import (
-    TestModelBase,
     LightTrainDataloader,
+    LightTestMixin,
     LightValidationMixin,
-    LightTestMixin
+    TestModelBase
 )
 
 
@@ -150,3 +151,33 @@ def test_trainer_callback_system(tmpdir):
 
     assert test_callback.on_test_start_called
     assert test_callback.on_test_end_called
+
+
+def test_early_stopping_without_val_step(tmpdir):
+    """Test that early stopping falls back to training metrics when no validation is defined."""
+    tutils.reset_seed()
+
+    class ModelWithoutValStep(LightTrainDataloader, TestModelBase):
+
+        def training_step(self, *args, **kwargs):
+            output = super().training_step(*args, **kwargs)
+            loss = output['loss']  # could be any other training metric
+            output.update({'my_train_metric': loss})
+            return output
+
+    hparams = tutils.get_default_hparams()
+    model = ModelWithoutValStep(hparams)
+
+    stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        early_stop_callback=stopping,
+        overfit_pct=0.20,
+        max_epochs=10,
+    )
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    assert result == 1, 'training failed to complete'
+    assert trainer.current_epoch < trainer.max_epochs - 1
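
For context on the simplified guard in `training_loop.py`: `on_epoch_end` is now called regardless of whether a validation epoch ran, because the monitored key may come from the training metrics. Below is a rough sketch of the era's `EarlyStopping.on_epoch_end` logic, paraphrased and simplified rather than the verbatim implementation; attribute names such as `monitor_op`, `wait`, and `best` follow the 0.7.x callback.

def on_epoch_end(self, trainer, pl_module):
    # trainer.callback_metrics now also carries training-step metrics,
    # so this works even when validation is disabled or not implemented
    current = trainer.callback_metrics.get(self.monitor)
    if current is None:
        return False  # monitored key missing; the real code warns about this
    if self.monitor_op(current - self.min_delta, self.best):
        # improvement: remember it and reset the patience counter
        self.best = current
        self.wait = 0
    else:
        self.wait += 1
        if self.wait >= self.patience:
            return True  # signal the training loop to stop
    return False

This is why the test above only needs to return `my_train_metric` from `training_step`: the key lands in the callback metrics and the patience logic proceeds exactly as it would for a validation metric.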