Fix last checkpoint finding in filtered files with correct extension (#17072)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
d4614d043e
commit
67d3844818
|
@ -47,6 +47,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed checks for local file protocol due to fsspec changes in 2023.10.0 ([#19023](https://github.com/Lightning-AI/lightning/pull/19023))
|
||||
|
||||
|
||||
- Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072))
|
||||
|
||||
|
||||
|
||||
## [2.1.2] - 2023-11-15
|
||||
|
||||
|
|
|
@ -626,12 +626,13 @@ class ModelCheckpoint(Checkpoint):
|
|||
def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]:
|
||||
# find all checkpoints in the folder
|
||||
ckpt_path = self.__resolve_ckpt_dir(trainer)
|
||||
last_pattern = rf"^{self.CHECKPOINT_NAME_LAST}(-(\d+))?"
|
||||
|
||||
def _is_last(path: Path) -> bool:
|
||||
return path.suffix == self.FILE_EXTENSION and bool(re.match(last_pattern, path.stem))
|
||||
|
||||
if self._fs.exists(ckpt_path):
|
||||
return {
|
||||
os.path.normpath(p)
|
||||
for p in self._fs.ls(ckpt_path, detail=False)
|
||||
if self.CHECKPOINT_NAME_LAST in os.path.split(p)[1]
|
||||
}
|
||||
return {os.path.normpath(p) for p in self._fs.ls(ckpt_path, detail=False) if _is_last(Path(p))}
|
||||
return set()
|
||||
|
||||
def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
|
||||
|
|
|
@ -1509,3 +1509,28 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
|
|||
else:
|
||||
assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"} # no files deleted
|
||||
assert set(os.listdir(second)) == {"epoch=0-step=6.ckpt", "epoch=0-step=8.ckpt"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("name", "extension", "folder_contents", "expected"),
|
||||
[
|
||||
("last", ".ckpt", {}, {}),
|
||||
("any", ".any", {}, {}),
|
||||
("last", ".ckpt", {"last"}, {}),
|
||||
("any", ".any", {"last"}, {}),
|
||||
("last", ".ckpt", {"last", "last.ckpt"}, {"last.ckpt"}),
|
||||
("other", ".pt", {"last", "last.pt", "other.pt"}, {"other.pt"}),
|
||||
("last", ".ckpt", {"log.txt", "last-v0.ckpt", "last-v1.ckpt"}, {"last-v0.ckpt", "last-v1.ckpt"}),
|
||||
("other", ".pt", {"log.txt", "last-v0.ckpt", "other-v0.pt", "other-v1.pt"}, {"other-v0.pt", "other-v1.pt"}),
|
||||
],
|
||||
)
|
||||
def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_path):
|
||||
for file in folder_contents:
|
||||
(tmp_path / file).touch()
|
||||
|
||||
trainer = Trainer()
|
||||
callback = ModelCheckpoint(dirpath=tmp_path)
|
||||
callback.CHECKPOINT_NAME_LAST = name
|
||||
callback.FILE_EXTENSION = extension
|
||||
files = callback._find_last_checkpoints(trainer)
|
||||
assert files == {str(tmp_path / p) for p in expected}
|
||||
|
|
Loading…
Reference in New Issue