Fix ModelCheckpoint dirpath expanding home prefix (#19058)
This commit is contained in:
parent
85adf17328
commit
58c905b940
|
@ -53,6 +53,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed an issue where setting `CHECKPOINT_JOIN_CHAR` or `CHECKPOINT_EQUALS_CHAR` would only work on the `ModelCheckpoint` class but not on an instance ([#19054](https://github.com/Lightning-AI/lightning/pull/19054))
|
||||
|
||||
|
||||
- Fixed `ModelCheckpoint` not expanding the `dirpath` if it has the `~` (home) prefix ([#19058](https://github.com/Lightning-AI/lightning/pull/19058))
|
||||
|
||||
|
||||
|
||||
## [2.1.2] - 2023-11-15
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -467,7 +467,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
self._fs = get_filesystem(dirpath if dirpath else "")
|
||||
|
||||
if dirpath and _is_local_file_protocol(dirpath if dirpath else ""):
|
||||
dirpath = os.path.realpath(dirpath)
|
||||
dirpath = os.path.realpath(os.path.expanduser(dirpath))
|
||||
|
||||
self.dirpath = dirpath
|
||||
self.filename = filename
|
||||
|
|
|
@ -1536,3 +1536,17 @@ def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_p
|
|||
callback.FILE_EXTENSION = extension
|
||||
files = callback._find_last_checkpoints(trainer)
|
||||
assert files == {str(tmp_path / p) for p in expected}
|
||||
|
||||
|
||||
def test_expand_home():
|
||||
"""Test that the dirpath gets expanded if it contains `~`."""
|
||||
home_root = Path.home()
|
||||
|
||||
checkpoint = ModelCheckpoint(dirpath="~/checkpoints")
|
||||
assert checkpoint.dirpath == str(home_root / "checkpoints")
|
||||
checkpoint = ModelCheckpoint(dirpath=Path("~/checkpoints"))
|
||||
assert checkpoint.dirpath == str(home_root / "checkpoints")
|
||||
|
||||
# it is possible to have a folder with the name `~`
|
||||
checkpoint = ModelCheckpoint(dirpath="./~/checkpoints")
|
||||
assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints")
|
||||
|
|
Loading…
Reference in New Issue