Prune deprecated EarlyStopping(mode='auto') (#6167)

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

parent 46617d9021
commit 8b475278dd
@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))


+- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))
+
+
 ### Fixed

 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
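For users upgrading across this change, the practical effect is that the monitor direction must now be stated explicitly. A minimal migration sketch (the metric names 'val_loss' and 'val_acc' are illustrative placeholders, not part of this commit):

from pytorch_lightning.callbacks import EarlyStopping

# Before this change, mode='auto' inferred the direction from the metric name.
# early_stop = EarlyStopping(monitor='val_loss', mode='auto')  # no longer accepted

# After this change, pass the direction explicitly:
early_stop_loss = EarlyStopping(monitor='val_loss', mode='min')  # stop when the loss stops decreasing
early_stop_acc = EarlyStopping(monitor='val_acc', mode='max')    # stop when the accuracy stops increasing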
@@ -23,7 +23,7 @@ import numpy as np
 import torch

 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -40,23 +40,18 @@ class EarlyStopping(Callback):
         patience: number of validation epochs with no improvement
             after which training will be stopped. Default: ``3``.
         verbose: verbosity mode. Default: ``False``.
-        mode: one of {auto, min, max}. In `min` mode,
+        mode: one of ``'min'``, ``'max'``. In ``'min'`` mode,
             training will stop when the quantity
-            monitored has stopped decreasing; in `max`
+            monitored has stopped decreasing and in ``'max'``
             mode it will stop when the quantity
-            monitored has stopped increasing; in `auto`
-            mode, the direction is automatically inferred
-            from the name of the monitored quantity.
-
-            .. warning::
-               Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3.
+            monitored has stopped increasing.

         strict: whether to crash the training if `monitor` is
             not found in the validation metrics. Default: ``True``.

     Raises:
         MisconfigurationException:
-            If ``mode`` is none of ``"min"``, ``"max"``, and ``"auto"``.
+            If ``mode`` is none of ``"min"`` or ``"max"``.
         RuntimeError:
             If the metric ``monitor`` is not available.
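As a rough sketch of what the two remaining modes mean mechanically (based on the comparison operators visible further down in this file, not a verbatim excerpt of the implementation):

import torch

# 'min': an improvement means the monitored value decreased (compared with torch.lt);
# 'max': an improvement means the monitored value increased (compared with torch.gt).
mode_dict = {'min': torch.lt, 'max': torch.gt}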
@@ -78,7 +73,7 @@ class EarlyStopping(Callback):
         min_delta: float = 0.0,
         patience: int = 3,
         verbose: bool = False,
-        mode: str = 'auto',
+        mode: str = 'min',
         strict: bool = True,
     ):
         super().__init__()
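Note that with the default flipped from 'auto' to 'min', a callback that previously relied on name-based inference for an accuracy-style metric now needs mode='max' spelled out; a hedged example (the metric name 'val_acc' is a placeholder):

from pytorch_lightning.callbacks import EarlyStopping

# Omitting mode now means 'min'; for a metric that should increase, the
# direction must be given explicitly, since the old name-based inference
# ('acc', 'fmeasure') is removed by this commit.
early_stop = EarlyStopping(monitor='val_acc', mode='max', patience=3)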
@@ -92,31 +87,13 @@ class EarlyStopping(Callback):
         self.mode = mode
         self.warned_result_obj = False

-        self.__init_monitor_mode()
+        if self.mode not in self.mode_dict:
+            raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")

         self.min_delta *= 1 if self.monitor_op == torch.gt else -1
         torch_inf = torch.tensor(np.Inf)
         self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

-    def __init_monitor_mode(self):
-        if self.mode not in self.mode_dict and self.mode != 'auto':
-            raise MisconfigurationException(f"`mode` can be auto, {', '.join(self.mode_dict.keys())}, got {self.mode}")
-
-        # TODO: Update with MisconfigurationException when auto mode is removed in v1.3
-        if self.mode == 'auto':
-            rank_zero_warn(
-                "mode='auto' is deprecated in v1.1 and will be removed in v1.3."
-                " Default value for mode with be 'min' in v1.3.", DeprecationWarning
-            )
-
-            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
-                self.mode = 'max'
-            else:
-                self.mode = 'min'
-
-            if self.verbose > 0:
-                rank_zero_info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')
-
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)
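With the helper gone, validation happens eagerly in the constructor, so an unsupported value (including the old 'auto') fails at construction time instead of producing a deprecation warning; a small hedged check:

from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    EarlyStopping(monitor='val_loss', mode='auto')  # 'auto' is no longer a valid mode
except MisconfigurationException as err:
    print(err)  # roughly: `mode` can be min, max, got auto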
@@ -334,7 +334,7 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
     # Compute min_epochs latest step
     by_min_epochs = min_epochs * limit_train_batches

-    # Make sure the trainer stops for the max of all minimun requirements
+    # Make sure the trainer stops for the max of all minimum requirements
     assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \
         (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
@@ -342,5 +342,5 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi


 def test_early_stopping_mode_options():
-    with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
+    with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
         EarlyStopping(mode="unknown_option")
@@ -19,7 +19,6 @@ from pytorch_lightning import LightningModule


 def test_v1_3_0_deprecated_arguments(tmpdir):
-
     with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"):

         class DeprecatedHparamsModel(LightningModule):