# Copyright The Lightning AI team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import math import os import pickle from typing import List, Optional from unittest import mock from unittest.mock import Mock import cloudpickle import pytest import torch from lightning.pytorch import seed_everything, Trainer from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel _logger = logging.getLogger(__name__) def test_early_stopping_state_key(): early_stopping = EarlyStopping(monitor="val_loss") assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}" class EarlyStoppingTestRestore(EarlyStopping): # this class has to be defined outside the test function, otherwise we get pickle error def __init__(self, expected_state, *args, **kwargs): super().__init__(*args, **kwargs) self.expected_state = expected_state # cache the state for each epoch self.saved_states = [] def on_train_start(self, trainer, pl_module): if self.expected_state: assert self.state_dict() == self.expected_state def on_train_epoch_end(self, trainer, pl_module): super().on_train_epoch_end(trainer, pl_module) self.saved_states.append(self.state_dict().copy()) @RunIf(sklearn=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_resume_early_stopping_from_checkpoint(tmpdir): """Prevent regressions to bugs: https://github.com/Lightning-AI/lightning/issues/1464 https://github.com/Lightning-AI/lightning/issues/1463 """ seed_everything(42) model = ClassificationModel() dm = ClassifDataModule() checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1) early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss") trainer = Trainer( default_root_dir=tmpdir, callbacks=[early_stop_callback, checkpoint_callback], num_sanity_val_steps=0, max_epochs=4, ) trainer.fit(model, datamodule=dm) assert len(early_stop_callback.saved_states) == 4 checkpoint_filepath = checkpoint_callback.kth_best_model_path # ensure state is persisted properly checkpoint = torch.load(checkpoint_filepath) # the checkpoint saves "epoch + 1" early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"]] assert len(early_stop_callback.saved_states) == 4 es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}" assert checkpoint["callbacks"][es_name] == early_stop_callback_state # ensure state is reloaded properly (assertion in the callback) early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss") new_trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, callbacks=[early_stop_callback], ) with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"): new_trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_filepath) @RunIf(sklearn=True) def test_early_stopping_no_extraneous_invocations(tmpdir): """Test to ensure that callback methods aren't being invoked outside of the callback handler.""" model = ClassificationModel() dm = ClassifDataModule() early_stop_callback = EarlyStopping(monitor="train_loss") early_stop_callback._run_early_stopping_check = Mock() expected_count = 4 trainer = Trainer( default_root_dir=tmpdir, callbacks=[early_stop_callback], limit_train_batches=4, limit_val_batches=4, max_epochs=expected_count, enable_checkpointing=False, ) trainer.fit(model, datamodule=dm) assert trainer.early_stopping_callback == early_stop_callback assert trainer.early_stopping_callbacks == [early_stop_callback] assert early_stop_callback._run_early_stopping_check.call_count == expected_count @pytest.mark.parametrize( ("loss_values", "patience", "expected_stop_epoch"), [([6, 5, 5, 5, 5, 5], 3, 4), ([6, 5, 4, 4, 3, 3], 1, 3), ([6, 5, 6, 5, 5, 5], 3, 4)], ) def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expected_stop_epoch: int): """Test to ensure that early stopping is not triggered before patience is exhausted.""" class ModelOverrideValidationReturn(BoringModel): validation_return_values = torch.tensor(loss_values) def on_validation_epoch_end(self): loss = self.validation_return_values[self.current_epoch] self.log("test_val_loss", loss) model = ModelOverrideValidationReturn() early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True) trainer = Trainer( default_root_dir=tmpdir, callbacks=[early_stop_callback], num_sanity_val_steps=0, max_epochs=10, enable_progress_bar=False, ) trainer.fit(model) assert trainer.current_epoch - 1 == expected_stop_epoch @pytest.mark.parametrize("validation_step_none", [True, False]) @pytest.mark.parametrize( ("loss_values", "patience", "expected_stop_epoch"), [([6, 5, 5, 5, 5, 5], 3, 4), ([6, 5, 4, 4, 3, 3], 1, 3), ([6, 5, 6, 5, 5, 5], 3, 4)], ) def test_early_stopping_patience_train( tmpdir, validation_step_none: bool, loss_values: list, patience: int, expected_stop_epoch: int ): """Test to ensure that early stopping is not triggered before patience is exhausted.""" class ModelOverrideTrainReturn(BoringModel): train_return_values = torch.tensor(loss_values) def on_train_epoch_end(self): loss = self.train_return_values[self.current_epoch] self.log("train_loss", loss) model = ModelOverrideTrainReturn() if validation_step_none: model.validation_step = None early_stop_callback = EarlyStopping( monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[early_stop_callback], num_sanity_val_steps=0, max_epochs=10, enable_progress_bar=False, ) trainer.fit(model) assert trainer.current_epoch - 1 == expected_stop_epoch def test_pickling(): early_stopping = EarlyStopping(monitor="foo") early_stopping_pickled = pickle.dumps(early_stopping) early_stopping_loaded = pickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) early_stopping_pickled = cloudpickle.dumps(early_stopping) early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) @RunIf(sklearn=True) def test_early_stopping_no_val_step(tmpdir): """Test that early stopping callback falls back to training metrics when no validation defined.""" model = ClassificationModel() dm = ClassifDataModule() model.validation_step = None model.val_dataloader = None stopping = EarlyStopping(monitor="train_loss", min_delta=0.1, patience=0, check_on_train_epoch_end=True) trainer = Trainer(default_root_dir=tmpdir, callbacks=[stopping], overfit_batches=0.20, max_epochs=10) trainer.fit(model, datamodule=dm) assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.current_epoch < trainer.max_epochs - 1 @pytest.mark.parametrize( ("stopping_threshold", "divergence_threshold", "losses", "expected_epoch"), [ (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5), (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8), (None, 15.9, [9, 4, 2, 16, 32, 64], 3), ], ) def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch): class CurrentModel(BoringModel): def on_validation_epoch_end(self): val_loss = losses[self.current_epoch] self.log("abc", val_loss) model = CurrentModel() early_stopping = EarlyStopping( monitor="abc", stopping_threshold=stopping_threshold, divergence_threshold=divergence_threshold ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[early_stopping], limit_train_batches=0.2, limit_val_batches=0.2, max_epochs=20, ) trainer.fit(model) assert trainer.current_epoch - 1 == expected_epoch, "early_stopping failed" @pytest.mark.parametrize("stop_value", [torch.tensor(torch.inf), torch.tensor(torch.nan)]) def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value): losses = [4, 3, stop_value, 2, 1] expected_stop_epoch = 2 class CurrentModel(BoringModel): def on_validation_epoch_end(self): val_loss = losses[self.current_epoch] self.log("val_loss", val_loss) model = CurrentModel() early_stopping = EarlyStopping(monitor="val_loss", check_finite=True) trainer = Trainer( default_root_dir=tmpdir, callbacks=[early_stopping], limit_train_batches=0.2, limit_val_batches=0.2, max_epochs=10, ) trainer.fit(model) assert trainer.current_epoch - 1 == expected_stop_epoch assert early_stopping.stopped_epoch == expected_stop_epoch @pytest.mark.parametrize( ("limit_train_batches", "min_epochs", "min_steps", "stop_step"), [ # IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being # triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop (3, 0, 10, 10), (5, 0, 10, 10), # IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is # being triggered, THEN the trainer should continue until reaching # `trainer.global_step` == `min_epochs * len(train_dataloader)` (3, 2, 0, 6), (5, 2, 0, 10), # IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when # `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and # `min_steps` would be reached (3, 1, 10, 10), (5, 1, 10, 10), (3, 3, 10, 10), (5, 3, 10, 15), ], ) def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps, stop_step): if min_steps: assert limit_train_batches < min_steps class TestModel(BoringModel): def training_step(self, batch, batch_idx): self.log("foo", batch_idx) return super().training_step(batch, batch_idx) es_callback = EarlyStopping("foo") trainer = Trainer( default_root_dir=tmpdir, callbacks=es_callback, limit_val_batches=0, limit_train_batches=limit_train_batches, min_epochs=min_epochs, min_steps=min_steps, logger=False, enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, ) model = TestModel() expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs) # trigger early stopping directly after the first epoch side_effect = [(True, "")] * expected_epochs with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect): trainer.fit(model) assert trainer.should_stop # epochs continue until min steps are reached assert trainer.current_epoch == expected_epochs # steps continue until min steps are reached AND the epoch is exhausted assert trainer.global_step == stop_step def test_early_stopping_mode_options(): with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"): EarlyStopping(monitor="foo", mode="unknown_option") class EarlyStoppingModel(BoringModel): def __init__(self, expected_end_epoch: int, early_stop_on_train: bool, dist_diverge_epoch: Optional[int] = None): super().__init__() self.expected_end_epoch = expected_end_epoch self.early_stop_on_train = early_stop_on_train self.dist_diverge_epoch = dist_diverge_epoch def _dist_diverge(self): should_diverge = ( self.dist_diverge_epoch and self.current_epoch >= self.dist_diverge_epoch and self.trainer.global_rank == 0 ) return 10 if should_diverge else None def _epoch_end(self) -> None: losses = [8, 4, 2, 3, 4, 5, 8, 10] loss = self._dist_diverge() or losses[self.current_epoch] self.log("abc", torch.tensor(loss)) self.log("cba", torch.tensor(0)) def on_train_epoch_end(self): if not self.early_stop_on_train: return self._epoch_end() def on_validation_epoch_end(self): if self.early_stop_on_train: return self._epoch_end() def on_train_end(self) -> None: assert self.trainer.current_epoch - 1 == self.expected_end_epoch, "Early Stopping Failed" _ES_CHECK = {"check_on_train_epoch_end": True} _ES_CHECK_P3 = {"patience": 3, "check_on_train_epoch_end": True} _SPAWN_MARK = {"marks": RunIf(skip_windows=True)} @pytest.mark.parametrize( ("callbacks", "expected_stop_epoch", "check_on_train_epoch_end", "strategy", "devices", "dist_diverge_epoch"), [ ([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "auto", 1, None), ([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "auto", 1, None), pytest.param( [EarlyStopping("abc", patience=1), EarlyStopping("cba")], 2, False, "ddp_spawn", 2, 2, **_SPAWN_MARK ), pytest.param( [EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, None, **_SPAWN_MARK ), pytest.param( [EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, None, **_SPAWN_MARK ), ([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, "auto", 1, None), ([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, "auto", 1, None), pytest.param( [EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, "ddp_spawn", 2, None, **_SPAWN_MARK, ), pytest.param( [EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, "ddp_spawn", 2, None, **_SPAWN_MARK, ), ], ) def test_multiple_early_stopping_callbacks( tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, check_on_train_epoch_end: bool, strategy: str, devices: int, dist_diverge_epoch: Optional[int], ): """Ensure when using multiple early stopping callbacks we stop if any signals we should stop.""" model = EarlyStoppingModel(expected_stop_epoch, check_on_train_epoch_end, dist_diverge_epoch=dist_diverge_epoch) trainer = Trainer( default_root_dir=tmpdir, callbacks=callbacks, limit_train_batches=0.1, limit_val_batches=0.1, max_epochs=20, strategy=strategy, accelerator="cpu", devices=devices, ) trainer.fit(model) @pytest.mark.parametrize( "case", { "val_check_interval": {"val_check_interval": 0.3, "limit_train_batches": 10, "max_epochs": 10}, "check_val_every_n_epoch": {"check_val_every_n_epoch": 2, "max_epochs": 5}, }.items(), ) def test_check_on_train_epoch_end_smart_handling(tmpdir, case): class TestModel(BoringModel): def validation_step(self, batch, batch_idx): self.log("foo", 1) return super().validation_step(batch, batch_idx) case, kwargs = case model = TestModel() trainer = Trainer( default_root_dir=tmpdir, limit_val_batches=1, callbacks=EarlyStopping(monitor="foo"), enable_progress_bar=False, **kwargs, ) side_effect = [(False, "A"), (True, "B")] with mock.patch( "lightning.pytorch.callbacks.EarlyStopping._evaluate_stopping_criteria", side_effect=side_effect ) as es_mock: trainer.fit(model) assert es_mock.call_count == len(side_effect) if case == "val_check_interval": assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval) else: assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch def test_early_stopping_squeezes(): early_stopping = EarlyStopping(monitor="foo") trainer = Trainer() trainer.callback_metrics["foo"] = torch.tensor([[[0]]]) with mock.patch( "lightning.pytorch.callbacks.EarlyStopping._evaluate_stopping_criteria", return_value=(False, "") ) as es_mock: early_stopping._run_early_stopping_check(trainer) es_mock.assert_called_once_with(torch.tensor(0)) @pytest.mark.parametrize("trainer", [Trainer(), None]) @pytest.mark.parametrize( ("log_rank_zero_only", "world_size", "global_rank", "expected_log"), [ (False, 1, 0, "bar"), (False, 2, 0, "[rank: 0] bar"), (False, 2, 1, "[rank: 1] bar"), (True, 1, 0, "bar"), (True, 2, 0, "[rank: 0] bar"), (True, 2, 1, None), ], ) def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log): """Checks if log.info() gets called with expected message when used within EarlyStopping.""" # set the global_rank and world_size if trainer is not None # or else always expect the simple logging message if trainer: trainer.strategy.global_rank = global_rank trainer.strategy.world_size = world_size else: expected_log = "bar" with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock: EarlyStopping._log_info(trainer, "bar", log_rank_zero_only) # check log.info() was called or not with expected arg if expected_log: log_mock.assert_called_once_with(expected_log) else: log_mock.assert_not_called()