diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 9cefa18f7e..9e6411cd7b 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -47,6 +47,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed checks for local file protocol due to fsspec changes in 2023.10.0 ([#19023](https://github.com/Lightning-AI/lightning/pull/19023)) +- Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072)) + + ## [2.1.2] - 2023-11-15 diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 9a4a4cd285..e0140ff74f 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -626,12 +626,13 @@ class ModelCheckpoint(Checkpoint): def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]: # find all checkpoints in the folder ckpt_path = self.__resolve_ckpt_dir(trainer) + last_pattern = rf"^{self.CHECKPOINT_NAME_LAST}(-(\d+))?" + + def _is_last(path: Path) -> bool: + return path.suffix == self.FILE_EXTENSION and bool(re.match(last_pattern, path.stem)) + if self._fs.exists(ckpt_path): - return { - os.path.normpath(p) - for p in self._fs.ls(ckpt_path, detail=False) - if self.CHECKPOINT_NAME_LAST in os.path.split(p)[1] - } + return {os.path.normpath(p) for p in self._fs.ls(ckpt_path, detail=False) if _is_last(Path(p))} return set() def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index de6d1180f1..a7eb8b544a 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -1509,3 +1509,28 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path): else: assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"} # no files deleted assert set(os.listdir(second)) == {"epoch=0-step=6.ckpt", "epoch=0-step=8.ckpt"} + + +@pytest.mark.parametrize( + ("name", "extension", "folder_contents", "expected"), + [ + ("last", ".ckpt", {}, {}), + ("any", ".any", {}, {}), + ("last", ".ckpt", {"last"}, {}), + ("any", ".any", {"last"}, {}), + ("last", ".ckpt", {"last", "last.ckpt"}, {"last.ckpt"}), + ("other", ".pt", {"last", "last.pt", "other.pt"}, {"other.pt"}), + ("last", ".ckpt", {"log.txt", "last-v0.ckpt", "last-v1.ckpt"}, {"last-v0.ckpt", "last-v1.ckpt"}), + ("other", ".pt", {"log.txt", "last-v0.ckpt", "other-v0.pt", "other-v1.pt"}, {"other-v0.pt", "other-v1.pt"}), + ], +) +def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_path): + for file in folder_contents: + (tmp_path / file).touch() + + trainer = Trainer() + callback = ModelCheckpoint(dirpath=tmp_path) + callback.CHECKPOINT_NAME_LAST = name + callback.FILE_EXTENSION = extension + files = callback._find_last_checkpoints(trainer) + assert files == {str(tmp_path / p) for p in expected}