Fix ModelCheckpoint tests from incomplete PR (#19205)

* Update src/lightning/pytorch/trainer/trainer.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
awaelchli 2023-12-22 08:49:09 +01:00 committed by GitHub
parent 6ea0e2d632
commit 858803236e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 9 deletions

View File

@ -54,8 +54,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
@pytest.mark.parametrize( @pytest.mark.parametrize(
("k", "epochs", "val_check_interval", "expected"), [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 6)] ("k", "epochs", "val_check_interval", "expected"), [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 6)]
) )
@pytest.mark.parametrize("save_last", [False, True]) @pytest.mark.parametrize("save_last", [False, True, "link"])
def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int, save_last: bool): def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected, save_last):
class TestModel(BoringModel): class TestModel(BoringModel):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -79,8 +79,8 @@ def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float
) )
trainer.fit(model) trainer.fit(model)
if save_last: # save_last=True: last epochs are saved every step (so double the save calls)
expected = expected expected = expected * 2 if save_last is True else expected
assert save_mock.call_count == expected assert save_mock.call_count == expected

View File

@ -311,11 +311,9 @@ def test_callbacks_state_fit_ckpt_path(tmpdir):
"best_k_models", "best_k_models",
"kth_best_model_path", "kth_best_model_path",
"kth_value", "kth_value",
"last_model_path",
): ):
assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}" assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}"
# `before.last_model_path` is a symlink pointing to a checkpoint saved before that symlink was created,
# hence reloading that checkpoint will restore `after.last_model_path = ""`
assert after.last_model_path == ""
@RunIf(sklearn=True) @RunIf(sklearn=True)

View File

@ -61,7 +61,7 @@ def test_checkpoint_plugin_called(tmpdir):
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"} assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt" assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt" assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
assert checkpoint_plugin.save_checkpoint.call_count == 2 assert checkpoint_plugin.save_checkpoint.call_count == 4
assert checkpoint_plugin.remove_checkpoint.call_count == 1 assert checkpoint_plugin.remove_checkpoint.call_count == 1
trainer.test(model, ckpt_path=ck.last_model_path) trainer.test(model, ckpt_path=ck.last_model_path)
@ -88,7 +88,7 @@ def test_checkpoint_plugin_called(tmpdir):
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"} assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt" assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt" assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
assert checkpoint_plugin.save_checkpoint.call_count == 2 assert checkpoint_plugin.save_checkpoint.call_count == 4
assert checkpoint_plugin.remove_checkpoint.call_count == 1 assert checkpoint_plugin.remove_checkpoint.call_count == 1
trainer.test(model, ckpt_path=ck.last_model_path) trainer.test(model, ckpt_path=ck.last_model_path)