From dd2f5a02120e247a2c6033c9cfb78df796374bfe Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 25 Feb 2021 15:44:55 +0000 Subject: [PATCH] Fix for multiple callbacks (#6197) * Fix for multiple callbacks * Add CHANGELOG.md * Remove old params * Skip tests on windows using ddp * Change name of the variable to not clash with should stop, which is separate * Apply suggestions from code review * Fix params Co-authored-by: Jirka Borovec --- CHANGELOG.md | 5 +- pytorch_lightning/callbacks/early_stopping.py | 7 +-- tests/callbacks/test_early_stopping.py | 55 +++++++++++++++++++ 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98332ee496..925106c035 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161)) * from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve` - * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce` + * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce` - Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162)) @@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075)) +- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197)) + + ## [1.2.1] - 2021-02-23 ### Fixed diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index d188aebe96..6a2b75c7de 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -158,15 +158,12 @@ class EarlyStopping(Callback): if self.monitor_op(current - self.min_delta, self.best_score): self.best_score = current self.wait_count = 0 - should_stop = False else: self.wait_count += 1 - should_stop = self.wait_count >= self.patience - if bool(should_stop): + if self.wait_count >= self.patience: self.stopped_epoch = trainer.current_epoch trainer.should_stop = True # stop every ddp process if any world process decides to stop - should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop) - trainer.should_stop = should_stop + trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index f36b68c0da..6470e1837d 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -13,6 +13,7 @@ # limitations under the License. import os import pickle +import sys from unittest import mock import cloudpickle @@ -344,3 +345,57 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi def test_early_stopping_mode_options(): with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"): EarlyStopping(mode="unknown_option") + + +class EarlyStoppingModel(BoringModel): + + def __init__(self, expected_end_epoch): + super().__init__() + self.expected_end_epoch = expected_end_epoch + + def validation_epoch_end(self, outputs): + losses = [8, 4, 2, 3, 4, 5, 8, 10] + val_loss = losses[self.current_epoch] + self.log('abc', torch.tensor(val_loss)) + self.log('cba', torch.tensor(0)) + + def on_train_end(self) -> None: + assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed' + + +@pytest.mark.parametrize( + "callbacks, expected_stop_epoch, accelerator, num_processes", + [ + ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1), + ([EarlyStopping(monitor='cba', patience=3), + EarlyStopping(monitor='abc')], 3, None, 1), + pytest.param([EarlyStopping(monitor='abc'), + EarlyStopping(monitor='cba', patience=3)], + 3, + 'ddp_cpu', + 2, + marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")), + pytest.param([EarlyStopping(monitor='cba', patience=3), + EarlyStopping(monitor='abc')], + 3, + 'ddp_cpu', + 2, + marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")), + ], +) +def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir): + """ + Ensure when using multiple early stopping callbacks we stop if any signals we should stop. + """ + + model = EarlyStoppingModel(expected_stop_epoch) + + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=callbacks, + overfit_batches=0.20, + max_epochs=20, + accelerator=accelerator, + num_processes=num_processes + ) + trainer.fit(model)