# lightning/tests/callbacks/test_early_stopping.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import pickle
from typing import List, Optional
from unittest import mock
from unittest.mock import Mock

import cloudpickle
import numpy as np
import pytest
import torch

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel

_logger = logging.getLogger(__name__)


def test_early_stopping_state_key():
    early_stopping = EarlyStopping(monitor="val_loss")
    assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"
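

# A minimal illustrative sketch (not part of the original suite): `state_key` is what
# lets a checkpoint's "callbacks" dict hold several EarlyStopping instances side by
# side, so instances monitoring different metrics must produce distinct keys.
def test_early_stopping_state_key_distinguishes_instances():
    assert EarlyStopping(monitor="a").state_key != EarlyStopping(monitor="b").state_key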


class EarlyStoppingTestRestore(EarlyStopping):
    # this class has to be defined outside the test function, otherwise we get pickle error
    def __init__(self, expected_state, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expected_state = expected_state
        # cache the state for each epoch
        self.saved_states = []

    def on_train_start(self, trainer, pl_module):
        if self.expected_state:
            assert self.state_dict() == self.expected_state

    def on_train_epoch_end(self, trainer, pl_module):
        super().on_train_epoch_end(trainer, pl_module)
        self.saved_states.append(self.state_dict().copy())
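

# Note: at the time of writing, `EarlyStopping.state_dict()` is expected to carry the
# fields needed to resume the callback (`wait_count`, `stopped_epoch`, `best_score`,
# `patience`); `saved_states` above snapshots one such dict per epoch.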


def test_resume_early_stopping_from_checkpoint(tmpdir):
    """Prevent regressions to bugs:

    https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
    https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
    """
    seed_everything(42)
    model = ClassificationModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1)
    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback, checkpoint_callback],
        num_sanity_val_steps=0,
        max_epochs=4,
    )
    trainer.fit(model, datamodule=dm)

    assert len(early_stop_callback.saved_states) == 4

    checkpoint_filepath = checkpoint_callback.kth_best_model_path
    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1", so the matching cached state sits at index "epoch - 1"
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
    es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
    assert checkpoint["callbacks"][es_name] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
    new_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        callbacks=[early_stop_callback],
    )

    with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
        new_trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_filepath)


def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    early_stop_callback = EarlyStopping(monitor="train_loss")
    early_stop_callback._run_early_stopping_check = Mock()
    expected_count = 4
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=expected_count,
        enable_checkpointing=False,
    )
    trainer.fit(model, datamodule=dm)

    assert trainer.early_stopping_callback == early_stop_callback
    assert trainer.early_stopping_callbacks == [early_stop_callback]
    assert early_stop_callback._run_early_stopping_check.call_count == expected_count
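

# Worked example for the patience cases below: with losses [6, 5, 5, 5, 5, 5] and
# patience=3, the best score (5) is set at epoch 1; epochs 2-4 bring no improvement,
# so the wait counter reaches 3 at epoch 4 and training stops there (expected_stop_epoch=4).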


@pytest.mark.parametrize(
    "loss_values, patience, expected_stop_epoch",
    [([6, 5, 5, 5, 5, 5], 3, 4), ([6, 5, 4, 4, 3, 3], 1, 3), ([6, 5, 6, 5, 5, 5], 3, 4)],
)
def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expected_stop_epoch: int):
    """Test to ensure that early stopping is not triggered before patience is exhausted."""

    class ModelOverrideValidationReturn(BoringModel):
        validation_return_values = torch.tensor(loss_values)

        def validation_epoch_end(self, outputs):
            loss = self.validation_return_values[self.current_epoch]
            self.log("test_val_loss", loss)

    model = ModelOverrideValidationReturn()
    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback],
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_progress_bar=False,
    )
    trainer.fit(model)
    assert trainer.current_epoch - 1 == expected_stop_epoch
@pytest.mark.parametrize("validation_step_none", [True, False])
@pytest.mark.parametrize(
"loss_values, patience, expected_stop_epoch",
[([6, 5, 5, 5, 5, 5], 3, 4), ([6, 5, 4, 4, 3, 3], 1, 3), ([6, 5, 6, 5, 5, 5], 3, 4)],
)
def test_early_stopping_patience_train(
tmpdir, validation_step_none: bool, loss_values: list, patience: int, expected_stop_epoch: int
):
"""Test to ensure that early stopping is not triggered before patience is exhausted."""
class ModelOverrideTrainReturn(BoringModel):
train_return_values = torch.tensor(loss_values)
def training_epoch_end(self, outputs):
loss = self.train_return_values[self.current_epoch]
self.log("train_loss", loss)
model = ModelOverrideTrainReturn()
if validation_step_none:
model.validation_step = None
early_stop_callback = EarlyStopping(
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stop_callback],
num_sanity_val_steps=0,
max_epochs=10,
enable_progress_bar=False,
)
trainer.fit(model)
assert trainer.current_epoch - 1 == expected_stop_epoch


def test_pickling(tmpdir):
    early_stopping = EarlyStopping(monitor="foo")

    early_stopping_pickled = pickle.dumps(early_stopping)
    early_stopping_loaded = pickle.loads(early_stopping_pickled)
    assert vars(early_stopping) == vars(early_stopping_loaded)

    early_stopping_pickled = cloudpickle.dumps(early_stopping)
    early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
    assert vars(early_stopping) == vars(early_stopping_loaded)


def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    model.validation_step = None
    model.val_dataloader = None

    stopping = EarlyStopping(monitor="train_loss", min_delta=0.1, patience=0, check_on_train_epoch_end=True)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[stopping], overfit_batches=0.20, max_epochs=10)
    trainer.fit(model, datamodule=dm)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.current_epoch < trainer.max_epochs - 1
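

# Worked examples for the threshold cases below: with stopping_threshold=2.9 and losses
# [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], the first value better than the threshold is 2 at
# epoch 8, so training stops there; with divergence_threshold=15.9 and losses
# [9, 4, 2, 16, 32, 64], the first value worse than the threshold is 16 at epoch 3.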


@pytest.mark.parametrize(
    "stopping_threshold,divergence_threshold,losses,expected_epoch",
    [
        (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
        (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
        (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
    ],
)
def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch):
    class CurrentModel(BoringModel):
        def validation_epoch_end(self, outputs):
            val_loss = losses[self.current_epoch]
            self.log("abc", val_loss)

    model = CurrentModel()
    early_stopping = EarlyStopping(
        monitor="abc", stopping_threshold=stopping_threshold, divergence_threshold=divergence_threshold
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stopping],
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        max_epochs=20,
    )
    trainer.fit(model)
    assert trainer.current_epoch - 1 == expected_epoch, "early_stopping failed"
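

# The test below relies on `check_finite` stopping training as soon as the monitored
# value becomes NaN or infinite, regardless of patience (True is also the default).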
@pytest.mark.parametrize("stop_value", [torch.tensor(np.inf), torch.tensor(np.nan)])
def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
losses = [4, 3, stop_value, 2, 1]
expected_stop_epoch = 2
class CurrentModel(BoringModel):
def validation_epoch_end(self, outputs):
val_loss = losses[self.current_epoch]
self.log("val_loss", val_loss)
model = CurrentModel()
early_stopping = EarlyStopping(monitor="val_loss", check_finite=True)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stopping],
limit_train_batches=0.2,
limit_val_batches=0.2,
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch - 1 == expected_stop_epoch
assert early_stopping.stopped_epoch == expected_stop_epoch
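

# Worked example for the cases below: with limit_train_batches=3 and min_steps=10, early
# stopping requests a stop after every epoch, but training must continue for
# max(ceil(10 / 3), min_epochs) = 4 epochs, leaving global_step == 3 * 4 == 12.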
@pytest.mark.parametrize("limit_train_batches", (3, 5))
@pytest.mark.parametrize(
["min_epochs", "min_steps"],
[
# IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
# triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
(0, 10),
# IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
# being triggered, THEN the trainer should continue until reaching
# `trainer.global_step` == `min_epochs * len(train_dataloader)`
(2, 0),
# IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
# `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
# `min_steps` would be reached
(1, 10),
(3, 10),
],
)
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
if min_steps:
assert limit_train_batches < min_steps
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
self.log("foo", batch_idx)
return super().training_step(batch, batch_idx)
es_callback = EarlyStopping("foo")
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=es_callback,
limit_val_batches=0,
limit_train_batches=limit_train_batches,
min_epochs=min_epochs,
min_steps=min_steps,
logger=False,
enable_checkpointing=False,
enable_progress_bar=False,
enable_model_summary=False,
)
model = TestModel()
expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
# trigger early stopping directly after the first epoch
side_effect = [(True, "")] * expected_epochs
with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect):
trainer.fit(model)
assert trainer.should_stop
# epochs continue until min steps are reached
assert trainer.current_epoch == expected_epochs
# steps continue until min steps are reached AND the epoch is exhausted
# stopping mid-epoch is not supported
assert trainer.global_step == limit_train_batches * expected_epochs


def test_early_stopping_mode_options():
    with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
        EarlyStopping(monitor="foo", mode="unknown_option")


class EarlyStoppingModel(BoringModel):
    def __init__(self, expected_end_epoch: int, early_stop_on_train: bool):
        super().__init__()
        self.expected_end_epoch = expected_end_epoch
        self.early_stop_on_train = early_stop_on_train

    def _epoch_end(self) -> None:
        losses = [8, 4, 2, 3, 4, 5, 8, 10]
        loss = losses[self.current_epoch]
        self.log("abc", torch.tensor(loss))
        self.log("cba", torch.tensor(0))

    def training_epoch_end(self, outputs):
        if not self.early_stop_on_train:
            return
        self._epoch_end()

    def validation_epoch_end(self, outputs):
        if self.early_stop_on_train:
            return
        self._epoch_end()

    def on_train_end(self) -> None:
        assert self.trainer.current_epoch - 1 == self.expected_end_epoch, "Early Stopping Failed"
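

# The constant "cba" series above never improves after epoch 0, so a patience-3 callback
# monitoring it fires at epoch 3, before any callback watching the still-improving "abc"
# series would; this is what lets the parametrized test below verify that a single
# callback signalling stop is sufficient, regardless of callback order.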


_ES_CHECK = dict(check_on_train_epoch_end=True)
_ES_CHECK_P3 = dict(patience=3, check_on_train_epoch_end=True)
_SPAWN_MARK = dict(marks=RunIf(skip_windows=True))


@pytest.mark.parametrize(
    "callbacks, expected_stop_epoch, check_on_train_epoch_end, strategy, devices",
    [
        ([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, None, 1),
        ([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, None, 1),
        pytest.param([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, **_SPAWN_MARK),
        pytest.param([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, **_SPAWN_MARK),
        ([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, None, 1),
        ([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, None, 1),
        pytest.param(
            [EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)],
            3,
            True,
            "ddp_spawn",
            2,
            **_SPAWN_MARK,
        ),
        pytest.param(
            [EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)],
            3,
            True,
            "ddp_spawn",
            2,
            **_SPAWN_MARK,
        ),
    ],
)
def test_multiple_early_stopping_callbacks(
    tmpdir,
    callbacks: List[EarlyStopping],
    expected_stop_epoch: int,
    check_on_train_epoch_end: bool,
    strategy: Optional[str],
    devices: int,
):
    """Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""
    model = EarlyStoppingModel(expected_stop_epoch, check_on_train_epoch_end)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=callbacks,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        max_epochs=20,
        strategy=strategy,
        accelerator="cpu",
        devices=devices,
    )
    trainer.fit(model)
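

# Worked arithmetic for the `val_check_interval` case below: with limit_train_batches=10
# and val_check_interval=0.3, validation (and thus the stopping check) runs every
# int(10 * 0.3) = 3 training batches, so two evaluations leave global_step == 2 * 3 == 6.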


@pytest.mark.parametrize(
    "case",
    {
        "val_check_interval": {"val_check_interval": 0.3, "limit_train_batches": 10, "max_epochs": 10},
        "check_val_every_n_epoch": {"check_val_every_n_epoch": 2, "max_epochs": 5},
    }.items(),
)
def test_check_on_train_epoch_end_smart_handling(tmpdir, case):
    class TestModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            self.log("foo", 1)
            return super().validation_step(batch, batch_idx)

    case, kwargs = case
    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_val_batches=1,
        callbacks=EarlyStopping(monitor="foo"),
        enable_progress_bar=False,
        **kwargs,
    )

    side_effect = [(False, "A"), (True, "B")]
    with mock.patch(
        "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria", side_effect=side_effect
    ) as es_mock:
        trainer.fit(model)

    assert es_mock.call_count == len(side_effect)
    if case == "val_check_interval":
        assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval)
    else:
        assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch
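

# The test below checks that a monitored tensor of shape (1, 1, 1) is squeezed down to a
# 0-dimensional scalar before being handed to `_evaluate_stopping_criteria`.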


def test_early_stopping_squeezes():
    early_stopping = EarlyStopping(monitor="foo")
    trainer = Trainer()
    trainer.callback_metrics["foo"] = torch.tensor([[[0]]])

    with mock.patch(
        "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria", return_value=(False, "")
    ) as es_mock:
        early_stopping._run_early_stopping_check(trainer)

    es_mock.assert_called_once_with(torch.tensor(0))