Fix ModelCheckpoint dirpath expanding home prefix (#19058)

This commit is contained in:
Adrian Wälchli 2023-11-23 15:11:43 +01:00 committed by GitHub
parent 85adf17328
commit 58c905b940
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 1 deletions

View File

@ -53,6 +53,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where setting `CHECKPOINT_JOIN_CHAR` or `CHECKPOINT_EQUALS_CHAR` would only work on the `ModelCheckpoint` class but not on an instance ([#19054](https://github.com/Lightning-AI/lightning/pull/19054))
- Fixed `ModelCheckpoint` not expanding the `dirpath` if it has the `~` (home) prefix ([#19058](https://github.com/Lightning-AI/lightning/pull/19058))
## [2.1.2] - 2023-11-15
### Fixed

View File

@ -467,7 +467,7 @@ class ModelCheckpoint(Checkpoint):
self._fs = get_filesystem(dirpath if dirpath else "")
if dirpath and _is_local_file_protocol(dirpath if dirpath else ""):
dirpath = os.path.realpath(dirpath)
dirpath = os.path.realpath(os.path.expanduser(dirpath))
self.dirpath = dirpath
self.filename = filename

View File

@ -1536,3 +1536,17 @@ def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_p
callback.FILE_EXTENSION = extension
files = callback._find_last_checkpoints(trainer)
assert files == {str(tmp_path / p) for p in expected}
def test_expand_home():
"""Test that the dirpath gets expanded if it contains `~`."""
home_root = Path.home()
checkpoint = ModelCheckpoint(dirpath="~/checkpoints")
assert checkpoint.dirpath == str(home_root / "checkpoints")
checkpoint = ModelCheckpoint(dirpath=Path("~/checkpoints"))
assert checkpoint.dirpath == str(home_root / "checkpoints")
# it is possible to have a folder with the name `~`
checkpoint = ModelCheckpoint(dirpath="./~/checkpoints")
assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints")