From 8b475278dde95737ac42e2f9f3897fd040bd28ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 24 Feb 2021 14:26:33 +0100 Subject: [PATCH] Prune deprecated EarlyStopping(mode='auto') (#6167) Co-authored-by: Roger Shieh Co-authored-by: Rohit Gupta --- CHANGELOG.md | 3 ++ pytorch_lightning/callbacks/early_stopping.py | 39 ++++--------------- tests/callbacks/test_early_stopping.py | 4 +- tests/deprecated_api/test_remove_1-3.py | 1 - 4 files changed, 13 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b4dedb84c9..98332ee496 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162)) +- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167)) + + ### Fixed - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 384ce9699f..d188aebe96 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -23,7 +23,7 @@ import numpy as np import torch from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -40,23 +40,18 @@ class EarlyStopping(Callback): patience: number of validation epochs with no improvement after which training will be stopped. Default: ``3``. verbose: verbosity mode. Default: ``False``. - mode: one of {auto, min, max}. In `min` mode, + mode: one of ``'min'``, ``'max'``. In ``'min'`` mode, training will stop when the quantity - monitored has stopped decreasing; in `max` + monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity - monitored has stopped increasing; in `auto` - mode, the direction is automatically inferred - from the name of the monitored quantity. - - .. warning:: - Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3. + monitored has stopped increasing. strict: whether to crash the training if `monitor` is not found in the validation metrics. Default: ``True``. Raises: MisconfigurationException: - If ``mode`` is none of ``"min"``, ``"max"``, and ``"auto"``. + If ``mode`` is none of ``"min"`` or ``"max"``. RuntimeError: If the metric ``monitor`` is not available. @@ -78,7 +73,7 @@ class EarlyStopping(Callback): min_delta: float = 0.0, patience: int = 3, verbose: bool = False, - mode: str = 'auto', + mode: str = 'min', strict: bool = True, ): super().__init__() @@ -92,31 +87,13 @@ class EarlyStopping(Callback): self.mode = mode self.warned_result_obj = False - self.__init_monitor_mode() + if self.mode not in self.mode_dict: + raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}") self.min_delta *= 1 if self.monitor_op == torch.gt else -1 torch_inf = torch.tensor(np.Inf) self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - def __init_monitor_mode(self): - if self.mode not in self.mode_dict and self.mode != 'auto': - raise MisconfigurationException(f"`mode` can be auto, {', '.join(self.mode_dict.keys())}, got {self.mode}") - - # TODO: Update with MisconfigurationException when auto mode is removed in v1.3 - if self.mode == 'auto': - rank_zero_warn( - "mode='auto' is deprecated in v1.1 and will be removed in v1.3." - " Default value for mode with be 'min' in v1.3.", DeprecationWarning - ) - - if "acc" in self.monitor or self.monitor.startswith("fmeasure"): - self.mode = 'max' - else: - self.mode = 'min' - - if self.verbose > 0: - rank_zero_info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.') - def _validate_condition_metric(self, logs): monitor_val = logs.get(self.monitor) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 9d326f0455..f36b68c0da 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -334,7 +334,7 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi # Compute min_epochs latest step by_min_epochs = min_epochs * limit_train_batches - # Make sure the trainer stops for the max of all minimun requirements + # Make sure the trainer stops for the max of all minimum requirements assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \ (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs) @@ -342,5 +342,5 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi def test_early_stopping_mode_options(): - with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"): + with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"): EarlyStopping(mode="unknown_option") diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 1710bb8777..a0b84cff12 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -19,7 +19,6 @@ from pytorch_lightning import LightningModule def test_v1_3_0_deprecated_arguments(tmpdir): - with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"): class DeprecatedHparamsModel(LightningModule):