Fix ModelCheckpoint tests from incomplete PR (#19205)
* Update src/lightning/pytorch/trainer/trainer.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
6ea0e2d632
commit
858803236e
|
@ -54,8 +54,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("k", "epochs", "val_check_interval", "expected"), [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 6)]
|
("k", "epochs", "val_check_interval", "expected"), [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 6)]
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("save_last", [False, True])
|
@pytest.mark.parametrize("save_last", [False, True, "link"])
|
||||||
def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int, save_last: bool):
|
def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected, save_last):
|
||||||
class TestModel(BoringModel):
|
class TestModel(BoringModel):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -79,8 +79,8 @@ def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float
|
||||||
)
|
)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
||||||
if save_last:
|
# save_last=True: last epochs are saved every step (so double the save calls)
|
||||||
expected = expected
|
expected = expected * 2 if save_last is True else expected
|
||||||
assert save_mock.call_count == expected
|
assert save_mock.call_count == expected
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -311,11 +311,9 @@ def test_callbacks_state_fit_ckpt_path(tmpdir):
|
||||||
"best_k_models",
|
"best_k_models",
|
||||||
"kth_best_model_path",
|
"kth_best_model_path",
|
||||||
"kth_value",
|
"kth_value",
|
||||||
|
"last_model_path",
|
||||||
):
|
):
|
||||||
assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}"
|
assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}"
|
||||||
# `before.last_model_path` is a symlink pointing to a checkpoint saved before that symlink was created,
|
|
||||||
# hence reloading that checkpoint will restore `after.last_model_path = ""`
|
|
||||||
assert after.last_model_path == ""
|
|
||||||
|
|
||||||
|
|
||||||
@RunIf(sklearn=True)
|
@RunIf(sklearn=True)
|
||||||
|
|
|
@ -61,7 +61,7 @@ def test_checkpoint_plugin_called(tmpdir):
|
||||||
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
|
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
|
||||||
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
|
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
|
||||||
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
|
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
|
||||||
assert checkpoint_plugin.save_checkpoint.call_count == 2
|
assert checkpoint_plugin.save_checkpoint.call_count == 4
|
||||||
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
||||||
|
|
||||||
trainer.test(model, ckpt_path=ck.last_model_path)
|
trainer.test(model, ckpt_path=ck.last_model_path)
|
||||||
|
@ -88,7 +88,7 @@ def test_checkpoint_plugin_called(tmpdir):
|
||||||
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
|
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
|
||||||
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
|
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
|
||||||
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
|
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
|
||||||
assert checkpoint_plugin.save_checkpoint.call_count == 2
|
assert checkpoint_plugin.save_checkpoint.call_count == 4
|
||||||
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
||||||
|
|
||||||
trainer.test(model, ckpt_path=ck.last_model_path)
|
trainer.test(model, ckpt_path=ck.last_model_path)
|
||||||
|
|
Loading…
Reference in New Issue