Fix `ModelCheckpoint.CHECKPOINT_NAME_LAST` test interaction (#18993)
This commit is contained in:
parent
466f772e3e
commit
b4605b44ee
|
@ -485,12 +485,12 @@ def test_model_checkpoint_file_extension(tmpdir):
|
|||
assert set(expected) == set(os.listdir(tmpdir))
|
||||
|
||||
|
||||
def test_model_checkpoint_save_last(tmpdir):
|
||||
def test_model_checkpoint_save_last(tmpdir, monkeypatch):
|
||||
"""Tests that save_last produces only one last checkpoint."""
|
||||
seed_everything()
|
||||
model = LogInTwoMethods()
|
||||
epochs = 3
|
||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = "last-{epoch}"
|
||||
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_NAME_LAST", "last-{epoch}")
|
||||
model_checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=-1, save_last=True)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
|
@ -511,7 +511,6 @@ def test_model_checkpoint_save_last(tmpdir):
|
|||
)
|
||||
assert os.path.islink(tmpdir / last_filename)
|
||||
assert os.path.realpath(tmpdir / last_filename) == model_checkpoint._last_checkpoint_saved
|
||||
ModelCheckpoint.CHECKPOINT_NAME_LAST = "last"
|
||||
|
||||
|
||||
def test_model_checkpoint_link_checkpoint(tmp_path):
|
||||
|
|
Loading…
Reference in New Issue