512 lines
19 KiB
Python
512 lines
19 KiB
Python
# Copyright The Lightning AI team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import math
|
|
import os
|
|
import pickle
|
|
from typing import List, Optional
|
|
from unittest import mock
|
|
from unittest.mock import Mock
|
|
|
|
import cloudpickle
|
|
import pytest
|
|
import torch
|
|
|
|
from lightning.pytorch import seed_everything, Trainer
|
|
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
|
|
from lightning.pytorch.demos.boring_classes import BoringModel
|
|
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
|
from tests_pytorch.helpers.datamodules import ClassifDataModule
|
|
from tests_pytorch.helpers.runif import RunIf
|
|
from tests_pytorch.helpers.simple_models import ClassificationModel
|
|
|
|
# module-scoped logger, named after this module
_logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def test_early_stopping_state_key():
    """The ``state_key`` of an ``EarlyStopping`` callback embeds its class name, monitor and mode."""
    expected_key = "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"
    callback = EarlyStopping(monitor="val_loss")
    assert callback.state_key == expected_key
|
|
|
|
|
|
class EarlyStoppingTestRestore(EarlyStopping):
    """``EarlyStopping`` subclass that snapshots its state every epoch and can verify a restored state.

    This class has to be defined at module level (not inside the test function), otherwise we get a pickle error.
    """

    def __init__(self, expected_state, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # state dict the callback must hold when training starts; a falsy value disables the check
        self.expected_state = expected_state
        # one copy of ``state_dict()`` per finished training epoch
        self.saved_states = []

    def on_train_start(self, trainer, pl_module):
        # nothing to verify when no expected state was supplied
        if not self.expected_state:
            return
        assert self.state_dict() == self.expected_state

    def on_train_epoch_end(self, trainer, pl_module):
        super().on_train_epoch_end(trainer, pl_module)
        snapshot = self.state_dict().copy()
        self.saved_states.append(snapshot)
|
|
|
|
|
|
@RunIf(sklearn=True)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_resume_early_stopping_from_checkpoint(tmpdir):
    """Prevent regressions to bugs:

    https://github.com/Lightning-AI/lightning/issues/1464
    https://github.com/Lightning-AI/lightning/issues/1463
    """
    seed_everything(42)
    model = ClassificationModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1)
    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback, checkpoint_callback],
        num_sanity_val_steps=0,
        max_epochs=4,
    )
    trainer.fit(model, datamodule=dm)

    # one state snapshot is recorded at the end of each of the 4 training epochs
    assert len(early_stop_callback.saved_states) == 4

    checkpoint_filepath = checkpoint_callback.kth_best_model_path
    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # `saved_states[i]` was appended at the end of epoch `i`, so pick the snapshot for the
    # epoch recorded in the checkpoint
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"]]
    es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
    assert checkpoint["callbacks"][es_name] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
    new_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        callbacks=[early_stop_callback],
    )

    with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
        new_trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_filepath)
|
|
|
|
|
|
@RunIf(sklearn=True)
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
    num_epochs = 4
    model, datamodule = ClassificationModel(), ClassifDataModule()

    callback = EarlyStopping(monitor="train_loss")
    # swap the check for a mock so the number of invocations can be counted
    callback._run_early_stopping_check = Mock()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[callback],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=num_epochs,
        enable_checkpointing=False,
    )
    trainer.fit(model, datamodule=datamodule)

    assert trainer.early_stopping_callback == callback
    assert trainer.early_stopping_callbacks == [callback]
    # the check must have run exactly once per epoch
    assert callback._run_early_stopping_check.call_count == num_epochs
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("loss_values", "patience", "expected_stop_epoch"),
    [
        ([6, 5, 5, 5, 5, 5], 3, 4),
        ([6, 5, 4, 4, 3, 3], 1, 3),
        ([6, 5, 6, 5, 5, 5], 3, 4),
    ],
)
def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expected_stop_epoch: int):
    """Test to ensure that early stopping is not triggered before patience is exhausted."""
    scripted_losses = torch.tensor(loss_values)

    class ModelOverrideValidationReturn(BoringModel):
        # the value logged at the end of validation epoch ``i`` is ``loss_values[i]``
        validation_return_values = scripted_losses

        def on_validation_epoch_end(self):
            self.log("test_val_loss", self.validation_return_values[self.current_epoch])

    model = ModelOverrideValidationReturn()
    callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[callback],
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_progress_bar=False,
    )
    trainer.fit(model)

    # the stop epoch is one less than the trainer's final ``current_epoch``
    assert trainer.current_epoch - 1 == expected_stop_epoch
|
|
|
|
|
|
@pytest.mark.parametrize("validation_step_none", [True, False])
@pytest.mark.parametrize(
    ("loss_values", "patience", "expected_stop_epoch"),
    [
        ([6, 5, 5, 5, 5, 5], 3, 4),
        ([6, 5, 4, 4, 3, 3], 1, 3),
        ([6, 5, 6, 5, 5, 5], 3, 4),
    ],
)
def test_early_stopping_patience_train(
    tmpdir, validation_step_none: bool, loss_values: list, patience: int, expected_stop_epoch: int
):
    """Test to ensure that early stopping is not triggered before patience is exhausted."""
    scripted_losses = torch.tensor(loss_values)

    class ModelOverrideTrainReturn(BoringModel):
        # the value logged at the end of training epoch ``i`` is ``loss_values[i]``
        train_return_values = scripted_losses

        def on_train_epoch_end(self):
            self.log("train_loss", self.train_return_values[self.current_epoch])

    model = ModelOverrideTrainReturn()
    # optionally remove validation altogether; the callback is checked at train epoch end either way
    if validation_step_none:
        model.validation_step = None

    callback = EarlyStopping(
        monitor="train_loss",
        patience=patience,
        verbose=True,
        check_on_train_epoch_end=True,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[callback],
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_progress_bar=False,
    )
    trainer.fit(model)

    # the stop epoch is one less than the trainer's final ``current_epoch``
    assert trainer.current_epoch - 1 == expected_stop_epoch
|
|
|
|
|
|
def test_pickling():
    """An ``EarlyStopping`` callback survives a round trip through both ``pickle`` and ``cloudpickle``."""
    early_stopping = EarlyStopping(monitor="foo")
    for serializer in (pickle, cloudpickle):
        restored = serializer.loads(serializer.dumps(early_stopping))
        assert vars(early_stopping) == vars(restored)
|
|
|
|
|
|
@RunIf(sklearn=True)
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""
    model = ClassificationModel()
    datamodule = ClassifDataModule()
    # strip every validation hook from the model
    model.validation_step = None
    model.val_dataloader = None

    callback = EarlyStopping(
        monitor="train_loss",
        min_delta=0.1,
        patience=0,
        check_on_train_epoch_end=True,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[callback],
        overfit_batches=0.20,
        max_epochs=10,
    )
    trainer.fit(model, datamodule=datamodule)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # training must have ended before the final epoch
    assert trainer.current_epoch < trainer.max_epochs - 1
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("stopping_threshold", "divergence_threshold", "losses", "expected_epoch"),
    [
        (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
        (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
        (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
    ],
)
def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch):
    """Training ends at the expected epoch for the given ``stopping_threshold`` / ``divergence_threshold``."""

    class CurrentModel(BoringModel):
        def on_validation_epoch_end(self):
            # log the scripted loss for the current epoch
            self.log("abc", losses[self.current_epoch])

    model = CurrentModel()
    callback = EarlyStopping(
        monitor="abc",
        stopping_threshold=stopping_threshold,
        divergence_threshold=divergence_threshold,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[callback],
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        max_epochs=20,
    )
    trainer.fit(model)

    stopped_at = trainer.current_epoch - 1
    assert stopped_at == expected_epoch, "early_stopping failed"
|
|
|
|
|
|
@pytest.mark.parametrize("stop_value", [torch.tensor(torch.inf), torch.tensor(torch.nan)])
def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
    """With ``check_finite=True``, a non-finite monitored value stops training in the epoch it is logged."""
    expected_stop_epoch = 2
    # the non-finite value is logged at index ``expected_stop_epoch``
    losses = [4, 3, stop_value, 2, 1]

    class CurrentModel(BoringModel):
        def on_validation_epoch_end(self):
            self.log("val_loss", losses[self.current_epoch])

    model = CurrentModel()
    early_stopping = EarlyStopping(monitor="val_loss", check_finite=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stopping],
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        max_epochs=10,
    )
    trainer.fit(model)

    assert trainer.current_epoch - 1 == expected_stop_epoch
    assert early_stopping.stopped_epoch == expected_stop_epoch
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("limit_train_batches", "min_epochs", "min_steps", "stop_step"),
    [
        # IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
        # triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
        (3, 0, 10, 10),
        (5, 0, 10, 10),
        # IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
        # being triggered, THEN the trainer should continue until reaching
        # `trainer.global_step` == `min_epochs * len(train_dataloader)`
        (3, 2, 0, 6),
        (5, 2, 0, 10),
        # IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
        # `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
        # `min_steps` would be reached
        (3, 1, 10, 10),
        (5, 1, 10, 10),
        (3, 3, 10, 10),
        (5, 3, 10, 15),
    ],
)
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps, stop_step):
    """A stop request from early stopping is deferred until ``min_epochs`` and ``min_steps`` are satisfied."""
    if min_steps:
        assert limit_train_batches < min_steps

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            self.log("foo", batch_idx)
            return super().training_step(batch, batch_idx)

    es_callback = EarlyStopping("foo")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=es_callback,
        limit_val_batches=0,
        limit_train_batches=limit_train_batches,
        min_epochs=min_epochs,
        min_steps=min_steps,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    model = TestModel()

    # number of epochs required to satisfy both minimums
    epochs_for_min_steps = math.ceil(min_steps / limit_train_batches)
    expected_epochs = max(epochs_for_min_steps, min_epochs)

    # the stopping criterion answers "stop" on every evaluation, starting right after the first epoch
    always_stop = [(True, "")] * expected_epochs
    with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=always_stop):
        trainer.fit(model)

    assert trainer.should_stop
    # epochs continue until the minimum epochs/steps are reached
    assert trainer.current_epoch == expected_epochs
    # steps continue until min steps are reached AND the epoch is exhausted
    assert trainer.global_step == stop_step
|
|
|
|
|
|
def test_early_stopping_mode_options():
    """An unsupported ``mode`` is rejected when the callback is constructed."""
    expected_msg = "`mode` can be .* got unknown_option"
    with pytest.raises(MisconfigurationException, match=expected_msg):
        EarlyStopping(monitor="foo", mode="unknown_option")
|
|
|
|
|
|
class EarlyStoppingModel(BoringModel):
    """``BoringModel`` that logs a scripted loss curve under ``"abc"`` and a constant ``0`` under ``"cba"``.

    Args:
        expected_end_epoch: value that ``trainer.current_epoch - 1`` must equal when training ends
            (asserted in ``on_train_end``).
        early_stop_on_train: if ``True`` the metrics are logged in ``on_train_epoch_end``, otherwise in
            ``on_validation_epoch_end``.
        dist_diverge_epoch: if not ``None``, from this epoch on rank 0 logs ``10`` for ``"abc"`` instead of the
            scripted value, so rank 0 reports a different value than the other ranks.
    """

    def __init__(self, expected_end_epoch: int, early_stop_on_train: bool, dist_diverge_epoch: Optional[int] = None):
        super().__init__()
        self.expected_end_epoch = expected_end_epoch
        self.early_stop_on_train = early_stop_on_train
        self.dist_diverge_epoch = dist_diverge_epoch

    def _dist_diverge(self) -> Optional[int]:
        """Return ``10`` on rank 0 once ``dist_diverge_epoch`` has been reached, otherwise ``None``."""
        # compare against ``None`` explicitly: a truthiness test would silently disable divergence for epoch 0
        should_diverge = (
            self.dist_diverge_epoch is not None
            and self.current_epoch >= self.dist_diverge_epoch
            and self.trainer.global_rank == 0
        )
        return 10 if should_diverge else None

    def _epoch_end(self) -> None:
        """Log the (possibly diverged) loss for the current epoch as ``"abc"`` and ``0`` as ``"cba"``."""
        losses = [8, 4, 2, 3, 4, 5, 8, 10]
        loss = self._dist_diverge() or losses[self.current_epoch]
        self.log("abc", torch.tensor(loss))
        self.log("cba", torch.tensor(0))

    def on_train_epoch_end(self):
        if not self.early_stop_on_train:
            return
        self._epoch_end()

    def on_validation_epoch_end(self):
        if self.early_stop_on_train:
            return
        self._epoch_end()

    def on_train_end(self) -> None:
        assert self.trainer.current_epoch - 1 == self.expected_end_epoch, "Early Stopping Failed"
|
|
|
|
|
|
# ``EarlyStopping`` keyword arguments shared by several parametrized cases
_ES_CHECK = {"check_on_train_epoch_end": True}
_ES_CHECK_P3 = {"patience": 3, **_ES_CHECK}
# ``pytest.param`` keyword arguments that skip the spawn-based cases on Windows
_SPAWN_MARK = {"marks": RunIf(skip_windows=True)}
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("callbacks", "expected_stop_epoch", "check_on_train_epoch_end", "strategy", "devices", "dist_diverge_epoch"),
    [
        ([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "auto", 1, None),
        ([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "auto", 1, None),
        pytest.param(
            [EarlyStopping("abc", patience=1), EarlyStopping("cba")], 2, False, "ddp_spawn", 2, 2, **_SPAWN_MARK
        ),
        pytest.param(
            [EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, None, **_SPAWN_MARK
        ),
        pytest.param(
            [EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, None, **_SPAWN_MARK
        ),
        ([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, "auto", 1, None),
        ([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, "auto", 1, None),
        pytest.param(
            [EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)],
            3,
            True,
            "ddp_spawn",
            2,
            None,
            **_SPAWN_MARK,
        ),
        pytest.param(
            [EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)],
            3,
            True,
            "ddp_spawn",
            2,
            None,
            **_SPAWN_MARK,
        ),
    ],
)
def test_multiple_early_stopping_callbacks(
    tmpdir,
    callbacks: List[EarlyStopping],
    expected_stop_epoch: int,
    check_on_train_epoch_end: bool,
    strategy: str,
    devices: int,
    dist_diverge_epoch: Optional[int],
):
    """Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""
    # the final epoch is asserted by the model itself in ``on_train_end``
    model = EarlyStoppingModel(
        expected_stop_epoch,
        check_on_train_epoch_end,
        dist_diverge_epoch=dist_diverge_epoch,
    )

    Trainer(
        default_root_dir=tmpdir,
        callbacks=callbacks,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        max_epochs=20,
        strategy=strategy,
        accelerator="cpu",
        devices=devices,
    ).fit(model)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "case",
    {
        "val_check_interval": {"val_check_interval": 0.3, "limit_train_batches": 10, "max_epochs": 10},
        "check_val_every_n_epoch": {"check_val_every_n_epoch": 2, "max_epochs": 5},
    }.items(),
)
def test_check_on_train_epoch_end_smart_handling(tmpdir, case):
    """The early stopping check runs after each validation run when validation is not once per epoch."""

    class TestModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            self.log("foo", 1)
            return super().validation_step(batch, batch_idx)

    case_name, trainer_kwargs = case
    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_val_batches=1,
        callbacks=EarlyStopping(monitor="foo"),
        enable_progress_bar=False,
        **trainer_kwargs,
    )

    # first evaluation says "continue", the second says "stop"
    side_effect = [(False, "A"), (True, "B")]
    target = "lightning.pytorch.callbacks.EarlyStopping._evaluate_stopping_criteria"
    with mock.patch(target, side_effect=side_effect) as es_mock:
        trainer.fit(model)

    num_checks = len(side_effect)
    assert es_mock.call_count == num_checks
    if case_name == "val_check_interval":
        batches_per_check = int(trainer.limit_train_batches * trainer.val_check_interval)
        assert trainer.global_step == num_checks * batches_per_check
    else:
        assert trainer.current_epoch == num_checks * trainer.check_val_every_n_epoch
|
|
|
|
|
|
def test_early_stopping_squeezes():
    """A monitored tensor of shape ``(1, 1, 1)`` reaches the stopping criteria as the scalar ``tensor(0)``."""
    early_stopping = EarlyStopping(monitor="foo")
    trainer = Trainer()
    trainer.callback_metrics["foo"] = torch.tensor([[[0]]])

    target = "lightning.pytorch.callbacks.EarlyStopping._evaluate_stopping_criteria"
    with mock.patch(target, return_value=(False, "")) as es_mock:
        early_stopping._run_early_stopping_check(trainer)

    es_mock.assert_called_once_with(torch.tensor(0))
|
|
|
|
|
|
@pytest.mark.parametrize("trainer", [Trainer(), None])
@pytest.mark.parametrize(
    ("log_rank_zero_only", "world_size", "global_rank", "expected_log"),
    [
        (False, 1, 0, "bar"),
        (False, 2, 0, "[rank: 0] bar"),
        (False, 2, 1, "[rank: 1] bar"),
        (True, 1, 0, "bar"),
        (True, 2, 0, "[rank: 0] bar"),
        (True, 2, 1, None),
    ],
)
def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log):
    """Checks if log.info() gets called with expected message when used within EarlyStopping."""
    if not trainer:
        # without a trainer there is no rank information: the plain message is always expected
        expected_log = "bar"
    else:
        trainer.strategy.global_rank = global_rank
        trainer.strategy.world_size = world_size

    with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
        EarlyStopping._log_info(trainer, "bar", log_rank_zero_only)

    # a ``None`` expectation means nothing should have been logged on this rank
    if not expected_log:
        log_mock.assert_not_called()
    else:
        log_mock.assert_called_once_with(expected_log)
|