From 858803236e02653fadd25f2fe3c01ad1eadb5311 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 22 Dec 2023 08:49:09 +0100 Subject: [PATCH] Fix ModelCheckpoint tests from incomplete PR (#19205) * Update src/lightning/pytorch/trainer/trainer.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../checkpointing/test_checkpoint_callback_frequency.py | 8 ++++---- tests/tests_pytorch/models/test_restore.py | 4 +--- tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py | 4 ++-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py index 6ba2e85f06..16a3066a14 100644 --- a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py @@ -54,8 +54,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter @pytest.mark.parametrize( ("k", "epochs", "val_check_interval", "expected"), [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 6)] ) -@pytest.mark.parametrize("save_last", [False, True]) -def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int, save_last: bool): +@pytest.mark.parametrize("save_last", [False, True, "link"]) +def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected, save_last): class TestModel(BoringModel): def __init__(self): super().__init__() @@ -79,8 +79,8 @@ def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float ) trainer.fit(model) - if save_last: - expected = expected + # save_last=True: last epochs are saved every step (so double the save calls) + expected = expected * 2 if save_last is True else expected assert save_mock.call_count == expected diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index ccb8380ca0..36955cef46 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -311,11 +311,9 @@ def test_callbacks_state_fit_ckpt_path(tmpdir): "best_k_models", "kth_best_model_path", "kth_value", + "last_model_path", ): assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}" - # `before.last_model_path` is a symlink pointing to a checkpoint saved before that symlink was created, - # hence reloading that checkpoint will restore `after.last_model_path = ""` - assert after.last_model_path == "" @RunIf(sklearn=True) diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 3c94ab6929..4e4a569caa 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -61,7 +61,7 @@ def test_checkpoint_plugin_called(tmpdir): assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"} assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt" assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt" - assert checkpoint_plugin.save_checkpoint.call_count == 2 + assert checkpoint_plugin.save_checkpoint.call_count == 4 assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path) @@ -88,7 +88,7 @@ def test_checkpoint_plugin_called(tmpdir): assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"} assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt" assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt" - assert checkpoint_plugin.save_checkpoint.call_count == 2 + assert checkpoint_plugin.save_checkpoint.call_count == 4 assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path)