From b4605b44ee482f43b07c5cde8c226472d2432528 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 12 Nov 2023 11:01:25 +0100 Subject: [PATCH] Fix `ModelCheckpoint.CHECKPOINT_NAME_LAST` test interaction (#18993) --- tests/tests_pytorch/checkpointing/test_model_checkpoint.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 5ce0a1172e..62ec65e1b8 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -485,12 +485,12 @@ def test_model_checkpoint_file_extension(tmpdir): assert set(expected) == set(os.listdir(tmpdir)) -def test_model_checkpoint_save_last(tmpdir): +def test_model_checkpoint_save_last(tmpdir, monkeypatch): """Tests that save_last produces only one last checkpoint.""" seed_everything() model = LogInTwoMethods() epochs = 3 - ModelCheckpoint.CHECKPOINT_NAME_LAST = "last-{epoch}" + monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_NAME_LAST", "last-{epoch}") model_checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=-1, save_last=True) trainer = Trainer( default_root_dir=tmpdir, @@ -511,7 +511,6 @@ def test_model_checkpoint_save_last(tmpdir): ) assert os.path.islink(tmpdir / last_filename) assert os.path.realpath(tmpdir / last_filename) == model_checkpoint._last_checkpoint_saved - ModelCheckpoint.CHECKPOINT_NAME_LAST = "last" def test_model_checkpoint_link_checkpoint(tmp_path):