Deprecate the default `EarlyStopping` callback monitor value (#7907)
* removed monitor default value and added deprecation message
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* format change
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* requested changes
* added test
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* format changes
* typehint change
* Update CHANGELOG.md
* requested changes
* regex

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Parent: c1eac483e9
Commit: 8b73869369
CHANGELOG.md
@@ -172,6 +172,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))

+- Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))

 ### Removed

 - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
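For reference, a minimal sketch of the explicit construction this deprecation steers users toward (the `val_loss` key is illustrative and assumed to be logged by the LightningModule, e.g. via `self.log("val_loss", loss)` in `validation_step`):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Name the monitored metric explicitly; relying on the implicit
    # 'early_stop_on' default is deprecated and goes away in v1.6.
    early_stop = EarlyStopping(monitor="val_loss", patience=3, mode="min")
    trainer = Trainer(callbacks=[early_stop])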
pytorch_lightning/callbacks/early_stopping.py
@@ -26,7 +26,7 @@ import torch

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 log = logging.getLogger(__name__)
@@ -88,7 +88,7 @@ class EarlyStopping(Callback):

     def __init__(
         self,
-        monitor: str = 'early_stop_on',
+        monitor: Optional[str] = None,
         min_delta: float = 0.0,
         patience: int = 3,
         verbose: bool = False,
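The hunk above is the "typehint change" from the commit message: `monitor` becomes `Optional[str]` defaulting to `None`. The diff does not touch the `typing` imports, so `Optional` is presumably already imported in this module.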
@@ -100,7 +100,6 @@
         check_on_train_epoch_end: bool = False,
     ):
         super().__init__()
-        self.monitor = monitor
         self.min_delta = min_delta
         self.patience = patience
         self.verbose = verbose
@@ -120,6 +119,13 @@
         torch_inf = torch.tensor(np.Inf)
         self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

+        if monitor is None:
+            rank_zero_deprecation(
+                "The `EarlyStopping(monitor)` argument will be required starting in v1.6."
+                " For backward compatibility, setting this to `early_stop_on`."
+            )
+        self.monitor = monitor or "early_stop_on"
+
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)
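The direct `self.monitor = monitor` assignment removed in the previous hunk reappears here, after the deprecation check, so the `early_stop_on` fallback is applied in one place. As an illustration of the resulting behavior, a rough sketch of what a caller sees when the argument is omitted (assuming the deprecation is surfaced through Python's warnings machinery, as the new test below exercises via `pytest.deprecated_call`):

    import warnings

    from pytorch_lightning.callbacks import EarlyStopping

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        cb = EarlyStopping()  # monitor omitted -> deprecation path

    assert cb.monitor == "early_stop_on"  # backward-compatible fallback
    assert any("EarlyStopping(monitor)" in str(w.message) for w in caught)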
tests (deprecated-API test module)
@@ -15,6 +15,7 @@
 import pytest

 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
 from tests.helpers import BoringDataModule, BoringModel
@@ -172,3 +173,11 @@ def test_v1_6_0_datamodule_hooks_calls(tmpdir):
     assert dm.prepare_data_calls == 1
     assert dm.setup_calls == ['fit', None]
     assert dm.teardown_calls == ['validate', 'test']
+
+
+def test_v1_6_0_early_stopping_monitor(tmpdir):
+    with pytest.deprecated_call(
+        match=r"The `EarlyStopping\(monitor\)` argument will be required starting in v1.6."
+        " For backward compatibility, setting this to `early_stop_on`."
+    ):
+        EarlyStopping()
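A note on the "regex" item from the commit message: `pytest.deprecated_call(match=...)` treats the pattern as a regular expression (applied with `re.search`), which is why the parentheses in `EarlyStopping\(monitor\)` are escaped. To run just this check locally, something along these lines should work (the test module path is assumed, not confirmed by this diff):

    python -m pytest tests/deprecated_api -k v1_6_0_early_stopping_monitor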