# Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import os import pickle import re import time from argparse import Namespace from datetime import timedelta from logging import INFO from pathlib import Path from typing import Union from unittest import mock from unittest.mock import call, Mock, patch import cloudpickle import pytest import torch import yaml from torch import optim import pytorch_lightning as pl import tests.helpers.utils as tutils from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE from tests.helpers import BoringModel from tests.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: from omegaconf import Container, OmegaConf def test_model_checkpoint_state_key(): early_stopping = ModelCheckpoint(monitor="val_loss") expected_id = ( "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," " 'train_time_interval': None, 'save_on_train_epoch_end': None}" ) assert early_stopping.state_key == expected_id class LogInTwoMethods(BoringModel): def training_step(self, batch, batch_idx): out = super().training_step(batch, batch_idx) self.log("early_stop_on", out["loss"]) return out def validation_epoch_end(self, outputs): outs = torch.stack([x["x"] for x in outputs]).mean() self.log("val_acc", outs) def mock_training_epoch_loop(trainer): # do not use `unittest.Mock` because we need to store the return value calls = {} old_get_monitor_value = trainer.fit_loop.epoch_loop._get_monitor_value def mock(key): value = old_get_monitor_value(key) calls[trainer.current_epoch] = {key: value} return value trainer.fit_loop.epoch_loop._get_monitor_value = mock return calls @pytest.mark.parametrize( "validation_step_none,val_dataloaders_none,monitor", [(False, False, "val_log"), (True, False, "train_log_epoch"), (False, True, "val_log")], ) @pytest.mark.parametrize("reduce_lr_on_plateau", [False, True]) def test_model_checkpoint_score_and_ckpt( tmpdir, validation_step_none: bool, val_dataloaders_none: bool, monitor: str, reduce_lr_on_plateau: bool ): """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint data.""" max_epochs = 3 limit_train_batches = 5 limit_val_batches = 7 lr, gamma = 1e-1, 2 class CustomBoringModel(BoringModel): def __init__(self): super().__init__() self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) self.val_logs = torch.randn(max_epochs, limit_val_batches) self.scores = [] def training_step(self, batch, batch_idx): log_value = self.train_log_epochs[self.current_epoch, batch_idx] self.log("train_log", log_value, on_epoch=True) return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): log_value = self.val_logs[self.current_epoch, batch_idx] self.log("val_log", log_value) return super().validation_step(batch, batch_idx) def configure_optimizers(self): optimizer = optim.SGD(self.parameters(), lr=lr) if reduce_lr_on_plateau: lr_scheduler = { "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer), "monitor": monitor, "strict": True, } else: lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma) return [optimizer], [lr_scheduler] def on_train_epoch_end(self): if "train" in monitor: self.scores.append(self.trainer.logged_metrics[monitor]) def on_validation_epoch_end(self): if not self.trainer.sanity_checking and "val" in monitor: self.scores.append(self.trainer.logged_metrics[monitor]) filename = "{" + f"{monitor}" + ":.4f}-{epoch}" checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) model = CustomBoringModel() if validation_step_none: model.validation_step = None if val_dataloaders_none: model.val_dataloaders = None trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint], limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches, max_epochs=max_epochs, enable_progress_bar=False, ) calls = mock_training_epoch_loop(trainer) trainer.fit(model) ckpt_files = list(Path(tmpdir).glob("*.ckpt")) assert len(ckpt_files) == len(model.scores) == max_epochs for epoch in range(max_epochs): score = model.scores[epoch] expected_score = getattr(model, f"{monitor}s")[epoch].mean().item() assert math.isclose(score, expected_score, rel_tol=1e-4) expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt" chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename)) assert chk["epoch"] == epoch assert chk["global_step"] == limit_train_batches * (epoch + 1) mc_specific_data = chk["callbacks"][ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," " 'train_time_interval': None, 'save_on_train_epoch_end': True}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score if not reduce_lr_on_plateau: actual_step_count = chk["lr_schedulers"][0]["_step_count"] actual_lr = chk["lr_schedulers"][0]["_last_lr"][0] # checkpoint is saved after updating lr_scheduler states assert actual_step_count == epoch + 2 # step_count starts at 1 assert actual_lr == lr * gamma ** (epoch + 1) else: assert calls[epoch] == {monitor: score} @pytest.mark.parametrize( "val_check_interval,reduce_lr_on_plateau,epoch_aligned", [(0.25, True, True), (0.25, False, True), (0.42, False, False)], ) def test_model_checkpoint_score_and_ckpt_val_check_interval( tmpdir, val_check_interval, reduce_lr_on_plateau, epoch_aligned ): """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint data with val_check_interval.""" seed_everything(0) max_epochs = 3 limit_train_batches = 12 limit_val_batches = 7 lr, gamma = 1e-1, 2 monitor = "val_log" per_val_train_batches = int(limit_train_batches * val_check_interval) per_epoch_val_checks, leftover_train_batches = divmod(limit_train_batches, per_val_train_batches) class CustomBoringModel(BoringModel): def __init__(self): super().__init__() self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches) self.val_loop_count = 0 self.scores = [] def validation_step(self, batch, batch_idx): log_value = self.val_logs[self.val_loop_count, batch_idx] self.log("val_log", log_value) return super().validation_step(batch, batch_idx) def validation_epoch_end(self, outputs): self.val_loop_count += 1 super().validation_epoch_end(outputs) self.scores.append(self.trainer.logged_metrics[monitor]) def configure_optimizers(self): optimizer = optim.SGD(self.parameters(), lr=lr) if reduce_lr_on_plateau: lr_scheduler = { "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer), "monitor": monitor, "strict": True, } else: lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma) return [optimizer], [lr_scheduler] filename = "{" + f"{monitor}" + ":.4f}-{epoch}" checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) model = CustomBoringModel() trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint], limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches, max_epochs=max_epochs, val_check_interval=val_check_interval, enable_progress_bar=False, num_sanity_val_steps=0, ) calls = mock_training_epoch_loop(trainer) trainer.fit(model) def _make_assertions(epoch, ix): global_ix = ix + per_epoch_val_checks * epoch # checkpoint saved at the end of training epoch will have updated lr_scheduler states epoch_end_checkpoint = epoch_aligned and ix == (per_epoch_val_checks - 1) score = model.scores[global_ix] expected_score = getattr(model, f"{monitor}s")[global_ix].mean().item() expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt" assert math.isclose(score, expected_score, rel_tol=1e-4) chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename)) assert chk["epoch"] == epoch expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch) assert chk["global_step"] == expected_global_step mc_specific_data = chk["callbacks"][ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," " 'train_time_interval': None, 'save_on_train_epoch_end': False}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score if not reduce_lr_on_plateau: actual_step_count = chk["lr_schedulers"][0]["_step_count"] actual_lr = chk["lr_schedulers"][0]["_last_lr"][0] assert actual_step_count == epoch + 1 + epoch_end_checkpoint assert actual_lr == lr * gamma ** (epoch + epoch_end_checkpoint) return score ckpt_files = list(Path(tmpdir).glob("*.ckpt")) assert len(ckpt_files) == len(model.scores) == per_epoch_val_checks * max_epochs for epoch in range(max_epochs): for i in range(per_epoch_val_checks): score = _make_assertions(epoch, i) if reduce_lr_on_plateau: assert calls[epoch] == {monitor: score} @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int): """Test that dirpath=None in checkpoint callback is valid and that ckpt_path is set correctly.""" tutils.reset_seed() model = LogInTwoMethods() checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=None, filename="{epoch}", save_top_k=save_top_k) max_epochs = 2 trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint], overfit_batches=0.20, max_epochs=max_epochs) trainer.fit(model) assert checkpoint.dirpath == tmpdir / trainer.logger.name / "version_0" / "checkpoints" if save_top_k == -1: ckpt_files = os.listdir(checkpoint.dirpath) expected_ckpt_files = [f"epoch={i}.ckpt" for i in range(max_epochs)] assert len(ckpt_files) == len(expected_ckpt_files) == max_epochs assert set(ckpt_files) == set(expected_ckpt_files) @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) def test_model_checkpoint_to_yaml(tmpdir, save_top_k: int): """Test that None in checkpoint callback is valid and that chkp_path is set correctly.""" tutils.reset_seed() model = LogInTwoMethods() checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_top_k=save_top_k) trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint], overfit_batches=0.20, max_epochs=2) trainer.fit(model) path_yaml = os.path.join(tmpdir, "best_k_models.yaml") checkpoint.to_yaml(path_yaml) d = yaml.full_load(open(path_yaml)) best_k = dict(checkpoint.best_k_models.items()) assert d == best_k @pytest.mark.parametrize("logger_version,expected", [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")]) def test_model_checkpoint_path(tmpdir, logger_version: Union[None, int, str], expected: str): """Test that "version_" prefix is only added when logger's version is an integer.""" tutils.reset_seed() model = LogInTwoMethods() logger = TensorBoardLogger(str(tmpdir), version=logger_version) trainer = Trainer(default_root_dir=tmpdir, overfit_batches=0.2, max_epochs=2, logger=logger) trainer.fit(model) ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name assert ckpt_version == expected def test_pickling(tmpdir): ckpt = ModelCheckpoint(dirpath=tmpdir) ckpt_pickled = pickle.dumps(ckpt) ckpt_loaded = pickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) ckpt_pickled = cloudpickle.dumps(ckpt) ckpt_loaded = cloudpickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) class ModelCheckpointTestInvocations(ModelCheckpoint): # this class has to be defined outside the test function, otherwise we get pickle error # due to the way ddp process is launched def __init__(self, expected_count, *args, **kwargs): super().__init__(*args, **kwargs) self.expected_count = expected_count self.state_dict_count = 0 def on_train_start(self, trainer, pl_module): torch.save = Mock(wraps=torch.save) def state_dict(self): super().state_dict() self.state_dict_count += 1 def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) assert self.best_model_path assert self.best_model_score assert self.state_dict_count == self.expected_count if trainer.is_global_zero: assert torch.save.call_count == self.expected_count else: assert torch.save.call_count == 0 @RunIf(skip_windows=True) def test_model_checkpoint_no_extraneous_invocations(tmpdir): """Test to ensure that the model callback saves the checkpoints only once in distributed mode.""" model = LogInTwoMethods() num_epochs = 4 model_checkpoint = ModelCheckpointTestInvocations(monitor="early_stop_on", expected_count=num_epochs, save_top_k=-1) trainer = Trainer( strategy="ddp_spawn", accelerator="cpu", devices=2, default_root_dir=tmpdir, callbacks=[model_checkpoint], max_epochs=num_epochs, ) trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" def test_model_checkpoint_format_checkpoint_name(tmpdir): # empty filename: ckpt_name = ModelCheckpoint._format_checkpoint_name("", {"epoch": 3, "step": 2}) assert ckpt_name == "epoch=3-step=2" ckpt_name = ModelCheckpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test") assert ckpt_name == "test-epoch=3-step=2" # no groups case: ckpt_name = ModelCheckpoint._format_checkpoint_name("ckpt", {}, prefix="test") assert ckpt_name == "test-ckpt" # no prefix ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03}) assert ckpt_name == "epoch=003-acc=0.03" # prefix char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR ModelCheckpoint.CHECKPOINT_JOIN_CHAR = "@" ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test") assert ckpt_name == "test@epoch=3,acc=0.03000" ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org # no dirpath set ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=None).format_checkpoint_name({"epoch": 3, "step": 2}) assert ckpt_name == "epoch=3-step=2.ckpt" ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath="").format_checkpoint_name({"epoch": 5, "step": 4}) assert ckpt_name == "epoch=5-step=4.ckpt" # CWD ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=".").format_checkpoint_name({"epoch": 3, "step": 4}) assert ckpt_name == str(Path(".").resolve() / "epoch=3-step=4.ckpt") # with version ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, filename="name") ckpt_name = ckpt.format_checkpoint_name({}, ver=3) assert ckpt_name == tmpdir / "name-v3.ckpt" # using slashes ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=None, filename="{epoch}_{val/loss:.5f}") ckpt_name = ckpt.format_checkpoint_name({"epoch": 4, "val/loss": 0.03}) assert ckpt_name == "epoch=4_val/loss=0.03000.ckpt" # auto_insert_metric_name=False ckpt_name = ModelCheckpoint._format_checkpoint_name( "epoch={epoch:03d}-val_acc={val/acc}", {"epoch": 3, "val/acc": 0.03}, auto_insert_metric_name=False ) assert ckpt_name == "epoch=003-val_acc=0.03" # dots in the metric name ckpt_name = ModelCheckpoint._format_checkpoint_name( "mAP@0.50={val/mAP@0.50:.4f}", {"val/mAP@0.50": 0.2}, auto_insert_metric_name=False ) assert ckpt_name == "mAP@0.50=0.2000" class ModelCheckpointExtensionTest(ModelCheckpoint): FILE_EXTENSION = ".tpkc" def test_model_checkpoint_file_extension(tmpdir): """Test ModelCheckpoint with different file extension.""" model = LogInTwoMethods() model_checkpoint = ModelCheckpointExtensionTest( monitor="early_stop_on", dirpath=tmpdir, save_top_k=1, save_last=True ) trainer = Trainer(default_root_dir=tmpdir, callbacks=[model_checkpoint], max_steps=1, logger=False) trainer.fit(model) expected = ["epoch=0-step=1.tpkc", "last.tpkc"] assert set(expected) == set(os.listdir(tmpdir)) def test_model_checkpoint_save_last(tmpdir): """Tests that save_last produces only one last checkpoint.""" seed_everything() model = LogInTwoMethods() epochs = 3 ModelCheckpoint.CHECKPOINT_NAME_LAST = "last-{epoch}" model_checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=-1, save_last=True) trainer = Trainer( default_root_dir=tmpdir, callbacks=[model_checkpoint], max_epochs=epochs, limit_train_batches=10, limit_val_batches=10, logger=False, ) trainer.fit(model) last_filename = model_checkpoint._format_checkpoint_name( ModelCheckpoint.CHECKPOINT_NAME_LAST, {"epoch": trainer.current_epoch - 1} ) last_filename = last_filename + ".ckpt" assert str(tmpdir / last_filename) == model_checkpoint.last_model_path assert set(os.listdir(tmpdir)) == set( [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20, 30])] + [last_filename] ) ModelCheckpoint.CHECKPOINT_NAME_LAST = "last" def test_invalid_top_k(tmpdir): """Make sure that a MisconfigurationException is raised for a negative save_top_k argument.""" with pytest.raises(MisconfigurationException, match=r".*Must be >= -1"): ModelCheckpoint(dirpath=tmpdir, save_top_k=-3) def test_none_monitor_top_k(tmpdir): """Test that a warning appears for positive top_k with monitor=None.""" with pytest.raises( MisconfigurationException, match=r"ModelCheckpoint\(save_top_k=3, monitor=None\) is not a valid*" ): ModelCheckpoint(dirpath=tmpdir, save_top_k=3) # These should not fail ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) ModelCheckpoint(dirpath=tmpdir, save_top_k=0) ModelCheckpoint(dirpath=tmpdir, save_top_k=1) def test_invalid_every_n_epochs(tmpdir): """Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument.""" with pytest.raises(MisconfigurationException, match=r".*Must be >= 0"): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-3) # These should not fail ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2) def test_invalid_every_n_train_steps(tmpdir): """Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument.""" with pytest.raises(MisconfigurationException, match=r".*Must be >= 0"): ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3) # These should not fail ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0) ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2) def test_invalid_trigger_combination(tmpdir): """Test that a MisconfigurationException is raised if more than one of every_n_epochs, every_n_train_steps, and train_time_interval are enabled together.""" with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_epochs=2) with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_epochs=2) with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2) # These should not fail ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=3) ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_epochs=0) ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=0, train_time_interval=timedelta(minutes=1)) def test_none_every_n_train_steps_val_epochs(tmpdir): checkpoint_callback = ModelCheckpoint(dirpath=tmpdir) assert checkpoint_callback.every_n_epochs == 1 assert checkpoint_callback._every_n_train_steps == 0 def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): """Test that it is possible to save all checkpoints when monitor=None.""" seed_everything() model = LogInTwoMethods() epochs = 2 checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True) trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint_callback], limit_train_batches=10, limit_val_batches=10, max_epochs=epochs, logger=False, ) with caplog.at_level(INFO): trainer.fit(model) assert "will duplicate the last checkpoint saved" in caplog.text # these should not be set if monitor is None assert checkpoint_callback.monitor is None assert checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=20.ckpt" assert checkpoint_callback.last_model_path == tmpdir / "last.ckpt" assert checkpoint_callback.best_model_score is None assert checkpoint_callback.best_k_models == {} assert checkpoint_callback.kth_best_model_path == "" # check that the correct ckpts were created expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20])] expected.append("last.ckpt") assert set(os.listdir(tmpdir)) == set(expected) @pytest.mark.parametrize("every_n_epochs", list(range(4))) def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( dirpath=tmpdir, filename="{epoch}", save_top_k=-1, every_n_epochs=every_n_epochs ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, limit_train_batches=1, limit_val_batches=1, logger=False, ) trainer.fit(model) # check that the correct ckpts were created, the modulo condition is checked in `ModelCheckpoint` expected = [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) def test_ckpt_every_n_train_steps(tmpdir): """Tests that the checkpoints are saved every n training steps.""" model = LogInTwoMethods() every_n_train_steps = 16 max_epochs = 2 epoch_length = 64 checkpoint_callback = ModelCheckpoint( filename="{step}", every_n_epochs=0, every_n_train_steps=every_n_train_steps, dirpath=tmpdir, save_top_k=-1, save_last=False, ) trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, enable_progress_bar=False, callbacks=[checkpoint_callback], logger=False, ) trainer.fit(model) expected = [ f"step={i}.ckpt" for i in range(every_n_train_steps, max_epochs * epoch_length + 1, every_n_train_steps) ] assert set(os.listdir(tmpdir)) == set(expected) @mock.patch("pytorch_lightning.callbacks.model_checkpoint.time") def test_model_checkpoint_train_time_interval(mock_datetime, tmpdir) -> None: """Tests that the checkpoints are saved at the specified time interval.""" seconds_per_batch = 7 start_time = time.monotonic() batches_per_epoch = 64 num_epochs = 2 max_batches = batches_per_epoch * num_epochs + 1 mock_datetime.monotonic.side_effect = [start_time + seconds_per_batch * i for i in range(max_batches)] model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, min_epochs=num_epochs, max_epochs=num_epochs, enable_progress_bar=False, callbacks=[ ModelCheckpoint( filename="{epoch}-{step}", dirpath=tmpdir, train_time_interval=timedelta(minutes=1), save_top_k=-1, save_last=False, ) ], logger=False, ) trainer.fit(model) # Each batch takes 7 sec and we checkpoint every minute. There are 64 # batches per epoch, so total time to run is 7*64*2 = 896 sec < 14.96 minutes, # so we should have 14 checkpoints. assert len(os.listdir(tmpdir)) == 14 def test_model_checkpoint_topk_zero(tmpdir): """Test that no checkpoints are saved when save_top_k=0.""" model = LogInTwoMethods() checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=True) trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=2, logger=False) trainer.fit(model) # these should not be set if monitor is None assert checkpoint_callback.monitor is None assert checkpoint_callback.best_model_path == "" assert checkpoint_callback.best_model_score is None assert checkpoint_callback.best_k_models == {} assert checkpoint_callback.kth_best_model_path == "" # check that only the last ckpt was created assert os.listdir(tmpdir) == ["last.ckpt"] assert checkpoint_callback.last_model_path == tmpdir / "last.ckpt" def test_model_checkpoint_topk_all(tmpdir): """Test that save_top_k=-1 tracks the best models when monitor key is provided.""" seed_everything(1000) epochs = 3 model = BoringModel() checkpoint_callback = ModelCheckpoint( dirpath=tmpdir, filename="{epoch}", monitor="epoch", mode="max", save_top_k=-1 ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, logger=False, val_check_interval=1.0, ) trainer.fit(model) assert checkpoint_callback.monitor == "epoch" assert checkpoint_callback.best_model_path == tmpdir / "epoch=2.ckpt" assert checkpoint_callback.best_model_score == epochs - 1 assert len(os.listdir(tmpdir)) == len(checkpoint_callback.best_k_models) == epochs assert set(checkpoint_callback.best_k_models.keys()) == {str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs)} assert checkpoint_callback.kth_best_model_path == tmpdir / "epoch=0.ckpt" def test_ckpt_metric_names(tmpdir): model = LogInTwoMethods() trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, gradient_clip_val=1.0, overfit_batches=0.20, enable_progress_bar=False, limit_train_batches=0.01, limit_val_batches=0.01, callbacks=[ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, filename="{val_loss:.2f}")], ) trainer.fit(model) # make sure the checkpoint we saved has the metric in the name ckpts = os.listdir(tmpdir) ckpts = [x for x in ckpts if "val_loss" in x] assert len(ckpts) == 1 val = re.sub("[^0-9.]", "", ckpts[0]) assert len(val) > 3 def test_default_checkpoint_behavior(tmpdir): seed_everything(1234) model = LogInTwoMethods() trainer = Trainer( default_root_dir=tmpdir, max_epochs=3, enable_progress_bar=False, limit_train_batches=5, limit_val_batches=5 ) with patch.object(trainer, "save_checkpoint", wraps=trainer.save_checkpoint) as save_mock: trainer.fit(model) results = trainer.test() assert len(results) == 1 save_dir = tmpdir / "lightning_logs" / "version_0" / "checkpoints" save_weights_only = trainer.checkpoint_callback.save_weights_only save_mock.assert_has_calls( [ call(save_dir / "epoch=0-step=5.ckpt", save_weights_only), call(save_dir / "epoch=1-step=10.ckpt", save_weights_only), call(save_dir / "epoch=2-step=15.ckpt", save_weights_only), ] ) ckpts = os.listdir(save_dir) assert len(ckpts) == 1 assert ckpts[0] == "epoch=2-step=15.ckpt" def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): """Tests that the save_last checkpoint contains the latest information.""" seed_everything(100) model = LogInTwoMethods() num_epochs = 3 model_checkpoint = ModelCheckpoint( monitor="early_stop_on", dirpath=tmpdir, filename="{epoch}", save_top_k=num_epochs, save_last=True ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[model_checkpoint], max_epochs=num_epochs, limit_train_batches=2, limit_val_batches=2, ) trainer.fit(model) path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt") path_last = str(tmpdir / "last.ckpt") assert path_last == model_checkpoint.last_model_path assert os.path.isfile(path_last_epoch) ckpt_last_epoch = torch.load(path_last_epoch) ckpt_last = torch.load(path_last) assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] ckpt_id = ( "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," " 'train_time_interval': None, 'save_on_train_epoch_end': True}" ) assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] # it is easier to load the model objects than to iterate over the raw dict of tensors model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch) model_last = LogInTwoMethods.load_from_checkpoint(model_checkpoint.last_model_path) for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()): assert w0.eq(w1).all() @pytest.mark.parametrize("mode", ["min", "max"]) def test_checkpointing_with_nan_as_first(tmpdir, mode): monitor = [float("nan")] monitor += [5, 7, 8] if mode == "max" else [8, 7, 5] class CurrentModel(LogInTwoMethods): def validation_epoch_end(self, outputs): val_loss = monitor[self.current_epoch] self.log("abc", val_loss) model = CurrentModel() callback = ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir) trainer = Trainer( callbacks=[callback], default_root_dir=tmpdir, val_check_interval=1.0, max_epochs=len(monitor), ) trainer.save_checkpoint = Mock() trainer.fit(model) # check that last one is also the best one assert trainer.save_checkpoint.call_count == len(monitor) assert mode == "min" and callback.best_model_score == 5 or mode == "max" and callback.best_model_score == 8 def test_checkpoint_repeated_strategy(tmpdir): """This test validates checkpoint can be called several times without increasing internally its global step if nothing run.""" checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}") class ExtendedBoringModel(BoringModel): def validation_step(self, batch, batch_idx): output = self.layer(batch) loss = self.loss(batch, output) self.log("val_loss", loss) model = ExtendedBoringModel() model.validation_epoch_end = None trainer_kwargs = { "max_epochs": 1, "limit_train_batches": 2, "limit_val_batches": 2, "limit_test_batches": 2, "enable_progress_bar": False, "enable_model_summary": False, } trainer = Trainer(**trainer_kwargs, callbacks=[checkpoint_callback]) trainer.fit(model) assert os.listdir(tmpdir) == ["epoch=00.ckpt"] for idx in range(4): # load from checkpoint trainer = pl.Trainer(**trainer_kwargs, default_root_dir=tmpdir) trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path) trainer.test(ckpt_path=checkpoint_callback.best_model_path, verbose=False) assert set(os.listdir(tmpdir)) == {"epoch=00.ckpt", "lightning_logs"} assert set(os.listdir(tmpdir / "lightning_logs")) == {f"version_{i}" for i in range(4)} def test_checkpoint_repeated_strategy_extended(tmpdir): """This test validates checkpoint can be called several times without increasing internally its global step if nothing run.""" class ExtendedBoringModel(BoringModel): def validation_step(self, batch, batch_idx): output = self.layer(batch) loss = self.loss(batch, output) return {"val_loss": loss} def validation_epoch_end(self, *_): ... def assert_trainer_init(trainer): assert trainer.global_step == 0 assert trainer.current_epoch == 0 def get_last_checkpoint(ckpt_dir): last = ckpt_dir.listdir(sort=True)[-1] return str(last) def assert_checkpoint_content(ckpt_dir): chk = pl_load(get_last_checkpoint(ckpt_dir)) # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so the # `current_epoch` count has not been increased yet assert chk["epoch"] == epochs - 1 assert chk["global_step"] == 4 def assert_checkpoint_log_dir(idx): lightning_logs = tmpdir / "lightning_logs" actual = [d.basename for d in lightning_logs.listdir(sort=True)] assert actual == [f"version_{i}" for i in range(idx + 1)] actual = [d.basename for d in ckpt_dir.listdir()] assert len(actual) == epochs, actual ckpt_dir = tmpdir / "checkpoints" checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1) epochs = 2 limit_train_batches = 2 trainer_config = dict( default_root_dir=tmpdir, max_epochs=epochs, limit_train_batches=limit_train_batches, limit_val_batches=3, limit_test_batches=4, callbacks=[checkpoint_cb], ) trainer = pl.Trainer(**trainer_config) assert_trainer_init(trainer) model = ExtendedBoringModel() trainer.fit(model) assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs assert_checkpoint_log_dir(0) assert_checkpoint_content(ckpt_dir) trainer.validate(model) assert trainer.current_epoch == epochs trainer.test(model) assert trainer.current_epoch == epochs for idx in range(1, 5): chk = get_last_checkpoint(ckpt_dir) assert_checkpoint_content(ckpt_dir) # load from checkpoint trainer = pl.Trainer(**trainer_config) assert_trainer_init(trainer) model = ExtendedBoringModel() trainer.test(model) assert_trainer_init(trainer) trainer.fit(model, ckpt_path=chk) assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs assert trainer.fit_loop.epoch_progress.current.processed == epochs trainer.validate(model) assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs assert trainer.fit_loop.epoch_progress.current.processed == epochs trainer.fit(model) assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs assert trainer.fit_loop.epoch_progress.current.processed == epochs assert_checkpoint_log_dir(idx) def test_configure_model_checkpoint(tmpdir): """Test all valid and invalid ways a checkpoint callback can be passed to the Trainer.""" kwargs = dict(default_root_dir=tmpdir) callback1 = ModelCheckpoint() callback2 = ModelCheckpoint() # no callbacks trainer = Trainer(enable_checkpointing=False, callbacks=[], **kwargs) assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks) assert trainer.checkpoint_callback is None # default configuration trainer = Trainer(callbacks=[], **kwargs) assert sum(1 for c in trainer.callbacks if isinstance(c, ModelCheckpoint)) == 1 assert isinstance(trainer.checkpoint_callback, ModelCheckpoint) # custom callback passed to callbacks list, enable_checkpointing=True is ignored trainer = Trainer(enable_checkpointing=True, callbacks=[callback1], **kwargs) assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1] assert trainer.checkpoint_callback == callback1 # multiple checkpoint callbacks trainer = Trainer(callbacks=[callback1, callback2], **kwargs) assert trainer.checkpoint_callback == callback1 assert trainer.checkpoint_callbacks == [callback1, callback2] with pytest.raises(MisconfigurationException, match="`enable_checkpointing=False` but found `ModelCheckpoint`"): Trainer(enable_checkpointing=False, callbacks=[callback1], **kwargs) def test_val_check_interval_checkpoint_files(tmpdir): """Test correct checkpoint naming when validating/checkpointing multiple times per epoch.""" model = LogInTwoMethods() model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="val_acc", mode="max") trainer = Trainer( default_root_dir=tmpdir, val_check_interval=0.2, max_epochs=1, limit_train_batches=10, callbacks=[model_checkpoint], logger=False, enable_progress_bar=False, enable_model_summary=False, ) trainer.fit(model) files = {p.basename for p in tmpdir.listdir()} assert files == {f"epoch=0-step={s}.ckpt" for s in [2, 4, 6, 8, 10]} def test_current_score(tmpdir): """Check that the current_score value is correct and was saved.""" class TestModel(BoringModel): def training_step(self, *args): self.log("foo", (self.current_epoch + 1) / 10) return super().training_step(*args) model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=3, monitor="foo", mode="min") trainer = Trainer( default_root_dir=tmpdir, max_epochs=3, limit_train_batches=1, limit_val_batches=1, callbacks=[model_checkpoint], logger=False, enable_progress_bar=False, enable_model_summary=False, ) trainer.fit(TestModel()) assert model_checkpoint.current_score == 0.3 ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] ckpts = [ ckpt["callbacks"][ "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," " 'train_time_interval': None, 'save_on_train_epoch_end': True}" ] for ckpt in ckpts ] assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3] @pytest.mark.parametrize("mode", ["min", "max"]) def test_current_score_when_nan(tmpdir, mode: str): """Check that ModelCheckpoint handles NaN values correctly.""" class TestModel(BoringModel): def training_step(self, *args): self.log("foo", float("nan")) return super().training_step(*args) model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo", mode=mode) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, callbacks=[model_checkpoint], logger=False, enable_progress_bar=False, enable_model_summary=False, ) trainer.fit(TestModel()) expected = float("inf" if mode == "min" else "-inf") assert model_checkpoint.best_model_score == expected assert model_checkpoint.current_score == expected @pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))]) def test_hparams_type(tmpdir, use_omegaconf): class TestModel(BoringModel): def __init__(self, hparams): super().__init__() self.save_hyperparameters(hparams) model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1) trainer = Trainer( max_epochs=1, default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=1, callbacks=[model_checkpoint], logger=False, enable_progress_bar=False, enable_model_summary=False, ) hp = {"test_hp_0": 1, "test_hp_1": 2} hp = OmegaConf.create(hp) if use_omegaconf else Namespace(**hp) model = TestModel(hp) trainer.fit(model) ckpt = trainer._checkpoint_connector.dump_checkpoint() if use_omegaconf: assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], Container) else: # make sure it's not AttributeDict assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is dict def test_ckpt_version_after_rerun_new_trainer(tmpdir): """Check that previous checkpoints are renamed to have the correct version suffix when new trainer instances are used.""" epochs = 2 for i in range(epochs): mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}") trainer = Trainer( max_epochs=epochs, limit_train_batches=1, limit_val_batches=1, default_root_dir=tmpdir, callbacks=[mc], logger=False, enable_progress_bar=False, enable_model_summary=False, ) trainer.fit(BoringModel()) # check best_k_models state expected = {"epoch=0-v1.ckpt", "epoch=1-v1.ckpt"} if i else {"epoch=0.ckpt", "epoch=1.ckpt"} assert {Path(f).name for f in mc.best_k_models} == expected # check created ckpts actual = {f.basename for f in tmpdir.listdir()} assert actual == {"epoch=0.ckpt", "epoch=1.ckpt", "epoch=0-v1.ckpt", "epoch=1-v1.ckpt"} def test_ckpt_version_after_rerun_same_trainer(tmpdir): """Check that previous checkpoints are renamed to have the correct version suffix when the same trainer instance is used.""" mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test") mc.STARTING_VERSION = 9 trainer = Trainer( max_epochs=2, limit_train_batches=1, limit_val_batches=1, default_root_dir=tmpdir, callbacks=[mc], logger=False, enable_progress_bar=False, enable_model_summary=False, ) trainer.fit(BoringModel()) trainer.fit_loop.max_epochs = 4 trainer.fit(BoringModel()) ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION - 1) expected = {"test.ckpt", *(f"test-v{i}.ckpt" for i in ckpt_range)} # check best_k_models state assert {Path(f).name for f in mc.best_k_models} == expected # check created ckpts assert set(os.listdir(tmpdir)) == expected def test_model_checkpoint_mode_options(): with pytest.raises(MisconfigurationException, match="`mode` can be .* but got unknown_option"): ModelCheckpoint(mode="unknown_option") def test_check_val_every_n_epochs_top_k_integration(tmpdir): model = BoringModel() mc = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", save_top_k=-1, filename="{epoch}") trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, max_epochs=5, check_val_every_n_epoch=2, callbacks=mc, enable_model_summary=False, logger=False, ) trainer.fit(model) assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"} def test_model_checkpoint_saveload_ckpt(tmpdir): def make_assertions(cb_restore, written_ckpt): expected_keys = { "dirpath": False, "best_model_score": False, "kth_best_model_path": False, "kth_value": False, "best_k_models": False, "last_model_path": False, "best_model_path": True, } for key, should_match in expected_keys.items(): if should_match: assert getattr(cb_restore, key) == written_ckpt[key] else: assert getattr(cb_restore, key) != written_ckpt[key] class CustomModelCheckpoint(ModelCheckpoint): def on_load_checkpoint(self, *args, **kwargs): assert self.dirpath is not None return super().on_load_checkpoint(*args, **kwargs) ckpt = { "best_model_path": "epoch=10-step=1436.ckpt", "best_model_score": torch.tensor(2.246), "best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)}, "kth_best_model_path": "epoch=10-step=1436.ckpt", "kth_value": torch.tensor(2.246), "last_model_path": "last2245.ckpt", } # test state_dict cb_write = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, save_last=True) for key, val in ckpt.items(): setattr(cb_write, key, val) written_ckpt = cb_write.state_dict() for state in ckpt: assert ckpt[state] == written_ckpt[state] # Case - 1 # test load_state_dict # Notes: # 1. "current_score", "dirpath" and "monitor" are currently not restored by load_state_dict. # We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them. # 2. "current_score" is left as initialized, i.e. None, and can therefore also be asserted # 3. When a different `dirpath` is passed to `ModelCheckpoint` to resume training, only # `best_model_path` and `last_model_path` are reloaded (reloading for others is stopped). cb_restore = ModelCheckpoint(dirpath=tmpdir + "/restore", monitor=None, save_top_k=-1, save_last=True) with pytest.warns(UserWarning, match="The dirpath has changed from*"): cb_restore.load_state_dict(written_ckpt) make_assertions(cb_restore, written_ckpt) # Case - 2 # Make sure that everything runs when dirpath is not initialized explicitly cb_restore = CustomModelCheckpoint() cb_restore.setup(Trainer(), BoringModel()) with pytest.warns(UserWarning, match="The dirpath has changed from*"): cb_restore.load_state_dict(written_ckpt) make_assertions(cb_restore, written_ckpt) def test_resume_training_preserves_old_ckpt_last(tmpdir): """Ensures that the last saved checkpoint is not deleted from the previous folder when training is resumed from the old checkpoint.""" model = BoringModel() trainer_kwargs = { "default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 3, "limit_val_batches": 0, "enable_model_summary": False, "logger": False, } mc_kwargs = { "filename": "{step}", "monitor": "step", "mode": "max", "save_last": True, "save_top_k": 2, "every_n_train_steps": 1, } trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs)) trainer.fit(model) # Make sure that the last checkpoint file exists in the dirpath passed (`tmpdir`) assert set(os.listdir(tmpdir / "checkpoints")) == {"last.ckpt", "step=2.ckpt", "step=3.ckpt"} # Training it for 2 epochs for extra surety, that nothing gets deleted after multiple epochs trainer_kwargs["max_epochs"] += 1 mc_kwargs["dirpath"] = f"{tmpdir}/new" trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs)) trainer.fit(model, ckpt_path=f"{tmpdir}/checkpoints/step=2.ckpt") # Ensure that the file is not deleted from the old folder assert os.path.isfile(f"{tmpdir}/checkpoints/last.ckpt") def test_save_last_saves_correct_last_model_path(tmpdir): mc = ModelCheckpoint(dirpath=tmpdir, save_last=True) mc.CHECKPOINT_NAME_LAST = "{foo}-last" trainer = Trainer(callbacks=mc) trainer.strategy.connect(BoringModel()) mc._save_last_checkpoint(trainer, {"foo": 1}) expected = "foo=1-last.ckpt" assert os.listdir(tmpdir) == [expected] full_path = str(tmpdir / expected) ckpt = torch.load(full_path) assert ckpt["callbacks"][mc.state_key]["last_model_path"] == full_path def test_none_monitor_saves_correct_best_model_path(tmpdir): mc = ModelCheckpoint(dirpath=tmpdir, monitor=None) trainer = Trainer(callbacks=mc) trainer.strategy.connect(BoringModel()) mc._save_none_monitor_checkpoint(trainer, {}) expected = "epoch=0-step=0.ckpt" assert os.listdir(tmpdir) == [expected] full_path = str(tmpdir / expected) ckpt = torch.load(full_path) assert ckpt["callbacks"][mc.state_key]["best_model_path"] == full_path def test_last_global_step_saved(): # this should not save anything model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo") trainer = Mock() monitor_candidates = {"foo": 123} model_checkpoint._save_topk_checkpoint(trainer, monitor_candidates) model_checkpoint._save_last_checkpoint(trainer, monitor_candidates) assert model_checkpoint._last_global_step_saved == 0 @pytest.mark.parametrize("every_n_epochs", (0, 5)) def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs): """Test that `save_last` ignores `every_n_epochs`.""" mc = ModelCheckpoint(every_n_epochs=every_n_epochs, save_last=True, save_top_k=0, save_on_train_epoch_end=True) trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, callbacks=mc, limit_train_batches=1, limit_val_batches=0, enable_progress_bar=False, enable_model_summary=False, logger=False, ) model = BoringModel() with patch.object(trainer, "save_checkpoint") as save_mock: trainer.fit(model) assert mc.last_model_path # a "last" ckpt was saved assert save_mock.call_count == trainer.max_epochs