Fix `ModelCheckpoint.CHECKPOINT_NAME_LAST` test interaction (#18993)

This commit is contained in:
Adrian Wälchli 2023-11-12 11:01:25 +01:00 committed by GitHub
parent 466f772e3e
commit b4605b44ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 3 deletions

View File

@ -485,12 +485,12 @@ def test_model_checkpoint_file_extension(tmpdir):
assert set(expected) == set(os.listdir(tmpdir))
def test_model_checkpoint_save_last(tmpdir):
def test_model_checkpoint_save_last(tmpdir, monkeypatch):
"""Tests that save_last produces only one last checkpoint."""
seed_everything()
model = LogInTwoMethods()
epochs = 3
ModelCheckpoint.CHECKPOINT_NAME_LAST = "last-{epoch}"
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_NAME_LAST", "last-{epoch}")
model_checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=-1, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
@ -511,7 +511,6 @@ def test_model_checkpoint_save_last(tmpdir):
)
assert os.path.islink(tmpdir / last_filename)
assert os.path.realpath(tmpdir / last_filename) == model_checkpoint._last_checkpoint_saved
ModelCheckpoint.CHECKPOINT_NAME_LAST = "last"
def test_model_checkpoint_link_checkpoint(tmp_path):