Fix ModelCheckpoint tests from incomplete PR (#19205)
* Update src/lightning/pytorch/trainer/trainer.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
6ea0e2d632
commit
858803236e
|
@ -54,8 +54,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
|
|||
@pytest.mark.parametrize(
|
||||
("k", "epochs", "val_check_interval", "expected"), [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 6)]
|
||||
)
|
||||
@pytest.mark.parametrize("save_last", [False, True])
|
||||
def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int, save_last: bool):
|
||||
@pytest.mark.parametrize("save_last", [False, True, "link"])
|
||||
def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected, save_last):
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -79,8 +79,8 @@ def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float
|
|||
)
|
||||
trainer.fit(model)
|
||||
|
||||
if save_last:
|
||||
expected = expected
|
||||
# save_last=True: last epochs are saved every step (so double the save calls)
|
||||
expected = expected * 2 if save_last is True else expected
|
||||
assert save_mock.call_count == expected
|
||||
|
||||
|
||||
|
|
|
@ -311,11 +311,9 @@ def test_callbacks_state_fit_ckpt_path(tmpdir):
|
|||
"best_k_models",
|
||||
"kth_best_model_path",
|
||||
"kth_value",
|
||||
"last_model_path",
|
||||
):
|
||||
assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}"
|
||||
# `before.last_model_path` is a symlink pointing to a checkpoint saved before that symlink was created,
|
||||
# hence reloading that checkpoint will restore `after.last_model_path = ""`
|
||||
assert after.last_model_path == ""
|
||||
|
||||
|
||||
@RunIf(sklearn=True)
|
||||
|
|
|
@ -61,7 +61,7 @@ def test_checkpoint_plugin_called(tmpdir):
|
|||
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
|
||||
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
|
||||
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
|
||||
assert checkpoint_plugin.save_checkpoint.call_count == 2
|
||||
assert checkpoint_plugin.save_checkpoint.call_count == 4
|
||||
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
||||
|
||||
trainer.test(model, ckpt_path=ck.last_model_path)
|
||||
|
@ -88,7 +88,7 @@ def test_checkpoint_plugin_called(tmpdir):
|
|||
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
|
||||
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
|
||||
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
|
||||
assert checkpoint_plugin.save_checkpoint.call_count == 2
|
||||
assert checkpoint_plugin.save_checkpoint.call_count == 4
|
||||
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
||||
|
||||
trainer.test(model, ckpt_path=ck.last_model_path)
|
||||
|
|
Loading…
Reference in New Issue