From 8b738693694a500f2e1903c962cdd00e3d90e2b6 Mon Sep 17 00:00:00 2001 From: Burhanuddin Rangwala Date: Fri, 11 Jun 2021 06:03:39 +0530 Subject: [PATCH] Deprecate the default `EarlyStopping` callback monitor value (#7907) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * removed monitor default value and added depreceation message * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * format change * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * requested changes * added test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * format changes * typehint change * Update CHANGELOG.md * requested changes * regex Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 3 +++ pytorch_lightning/callbacks/early_stopping.py | 12 +++++++++--- tests/deprecated_api/test_remove_1-6.py | 9 +++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 54f9ad91db..de24cc9daa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -172,6 +172,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891)) +- Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907)) + + ### Removed - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654)) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 242eeed808..e40bb7180c 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -26,7 +26,7 @@ import torch import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException log = logging.getLogger(__name__) @@ -88,7 +88,7 @@ class EarlyStopping(Callback): def __init__( self, - monitor: str = 'early_stop_on', + monitor: Optional[str] = None, min_delta: float = 0.0, patience: int = 3, verbose: bool = False, @@ -100,7 +100,6 @@ class EarlyStopping(Callback): check_on_train_epoch_end: bool = False, ): super().__init__() - self.monitor = monitor self.min_delta = min_delta self.patience = patience self.verbose = verbose @@ -120,6 +119,13 @@ class EarlyStopping(Callback): torch_inf = torch.tensor(np.Inf) self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + if monitor is None: + rank_zero_deprecation( + "The `EarlyStopping(monitor)` argument will be required starting in v1.6." + " For backward compatibility, setting this to `early_stop_on`." + ) + self.monitor = monitor or "early_stop_on" + def _validate_condition_metric(self, logs): monitor_val = logs.get(self.monitor) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 1b4f6cacfe..1dfbb91022 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -15,6 +15,7 @@ import pytest from pytorch_lightning import Trainer +from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin from tests.helpers import BoringDataModule, BoringModel @@ -172,3 +173,11 @@ def test_v1_6_0_datamodule_hooks_calls(tmpdir): assert dm.prepare_data_calls == 1 assert dm.setup_calls == ['fit', None] assert dm.teardown_calls == ['validate', 'test'] + + +def test_v1_6_0_early_stopping_monitor(tmpdir): + with pytest.deprecated_call( + match=r"The `EarlyStopping\(monitor\)` argument will be required starting in v1.6." + " For backward compatibility, setting this to `early_stop_on`." + ): + EarlyStopping()