Prune deprecated EarlyStopping(mode='auto') (#6167)

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Carlos Mocholí 2021-02-24 14:26:33 +01:00 committed by GitHub
parent 46617d9021
commit 8b475278dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 34 deletions

View File

@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))
### Fixed
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))

View File

@ -23,7 +23,7 @@ import numpy as np
import torch
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -40,23 +40,18 @@ class EarlyStopping(Callback):
patience: number of validation epochs with no improvement
after which training will be stopped. Default: ``3``.
verbose: verbosity mode. Default: ``False``.
mode: one of {auto, min, max}. In `min` mode,
mode: one of ``'min'``, ``'max'``. In ``'min'`` mode,
training will stop when the quantity
monitored has stopped decreasing; in `max`
monitored has stopped decreasing and in ``'max'``
mode it will stop when the quantity
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity.
.. warning::
Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3.
monitored has stopped increasing.
strict: whether to crash the training if `monitor` is
not found in the validation metrics. Default: ``True``.
Raises:
MisconfigurationException:
If ``mode`` is none of ``"min"``, ``"max"``, and ``"auto"``.
If ``mode`` is none of ``"min"`` or ``"max"``.
RuntimeError:
If the metric ``monitor`` is not available.
@ -78,7 +73,7 @@ class EarlyStopping(Callback):
min_delta: float = 0.0,
patience: int = 3,
verbose: bool = False,
mode: str = 'auto',
mode: str = 'min',
strict: bool = True,
):
super().__init__()
@ -92,31 +87,13 @@ class EarlyStopping(Callback):
self.mode = mode
self.warned_result_obj = False
self.__init_monitor_mode()
if self.mode not in self.mode_dict:
raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
self.min_delta *= 1 if self.monitor_op == torch.gt else -1
torch_inf = torch.tensor(np.Inf)
self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
def __init_monitor_mode(self):
if self.mode not in self.mode_dict and self.mode != 'auto':
raise MisconfigurationException(f"`mode` can be auto, {', '.join(self.mode_dict.keys())}, got {self.mode}")
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
if self.mode == 'auto':
rank_zero_warn(
"mode='auto' is deprecated in v1.1 and will be removed in v1.3."
" Default value for mode with be 'min' in v1.3.", DeprecationWarning
)
if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
self.mode = 'max'
else:
self.mode = 'min'
if self.verbose > 0:
rank_zero_info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')
def _validate_condition_metric(self, logs):
monitor_val = logs.get(self.monitor)

View File

@ -334,7 +334,7 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
# Compute min_epochs latest step
by_min_epochs = min_epochs * limit_train_batches
# Make sure the trainer stops for the max of all minimun requirements
# Make sure the trainer stops for the max of all minimum requirements
assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \
(trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
@ -342,5 +342,5 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
def test_early_stopping_mode_options():
with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
EarlyStopping(mode="unknown_option")

View File

@ -19,7 +19,6 @@ from pytorch_lightning import LightningModule
def test_v1_3_0_deprecated_arguments(tmpdir):
with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"):
class DeprecatedHparamsModel(LightningModule):