Fix last checkpoint finding in filtered files with correct extension (#17072)

Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
Yasser Souri 2023-11-21 14:12:02 -08:00 committed by GitHub
parent d4614d043e
commit 67d3844818
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 5 deletions

View File

@ -47,6 +47,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed checks for local file protocol due to fsspec changes in 2023.10.0 ([#19023](https://github.com/Lightning-AI/lightning/pull/19023))
- Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072))
## [2.1.2] - 2023-11-15

View File

@ -626,12 +626,13 @@ class ModelCheckpoint(Checkpoint):
def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]:
    """Collect the "last" checkpoint files found in the checkpoint directory.

    A directory entry is kept when its suffix equals ``self.FILE_EXTENSION`` and its
    stem starts with ``self.CHECKPOINT_NAME_LAST``.

    Args:
        trainer: The trainer handed to ``__resolve_ckpt_dir`` to determine the directory.

    Returns:
        The ``os.path.normpath``-normalized paths of the matching entries, or an empty
        set when the directory does not exist.

    """
    # find all checkpoints in the folder
    ckpt_path = self.__resolve_ckpt_dir(trainer)
    # Escape the (user-overridable) name so regex metacharacters in it are matched
    # literally. NOTE: the trailing group is optional and the pattern has no end
    # anchor, so any stem that begins with the name is accepted.
    last_pattern = rf"^{re.escape(self.CHECKPOINT_NAME_LAST)}(-(\d+))?"

    def _is_last(path: Path) -> bool:
        # Both conditions must hold: correct extension AND the "last" name prefix.
        return path.suffix == self.FILE_EXTENSION and bool(re.match(last_pattern, path.stem))

    if self._fs.exists(ckpt_path):
        return {os.path.normpath(p) for p in self._fs.ls(ckpt_path, detail=False) if _is_last(Path(p))}

    return set()
def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:

View File

@ -1509,3 +1509,28 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
else:
assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"} # no files deleted
assert set(os.listdir(second)) == {"epoch=0-step=6.ckpt", "epoch=0-step=8.ckpt"}
@pytest.mark.parametrize(
    ("name", "extension", "folder_contents", "expected"),
    [
        # Empty folder: nothing to find. `set()` keeps the type consistent with the other cases.
        ("last", ".ckpt", set(), set()),
        ("any", ".any", set(), set()),
        # A file without the configured extension is never reported.
        ("last", ".ckpt", {"last"}, set()),
        ("any", ".any", {"last"}, set()),
        ("last", ".ckpt", {"last", "last.ckpt"}, {"last.ckpt"}),
        ("other", ".pt", {"last", "last.pt", "other.pt"}, {"other.pt"}),
        # Versioned files are reported as long as name and extension both match.
        ("last", ".ckpt", {"log.txt", "last-v0.ckpt", "last-v1.ckpt"}, {"last-v0.ckpt", "last-v1.ckpt"}),
        ("other", ".pt", {"log.txt", "last-v0.ckpt", "other-v0.pt", "other-v1.pt"}, {"other-v0.pt", "other-v1.pt"}),
    ],
)
def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_path):
    """Check that ``_find_last_checkpoints`` filters on both the name and the file extension.

    Each entry of ``folder_contents`` is created as an empty file in ``tmp_path``; the
    callback's ``CHECKPOINT_NAME_LAST`` and ``FILE_EXTENSION`` are then overridden on the
    instance and the returned set is compared against the ``expected`` paths.

    """
    for file in folder_contents:
        (tmp_path / file).touch()

    trainer = Trainer()
    callback = ModelCheckpoint(dirpath=tmp_path)
    callback.CHECKPOINT_NAME_LAST = name
    callback.FILE_EXTENSION = extension
    files = callback._find_last_checkpoints(trainer)
    assert files == {str(tmp_path / p) for p in expected}