From 58c905b940e3c75d81cafb83ec142c42ca17dd47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Nov 2023 15:11:43 +0100 Subject: [PATCH] Fix ModelCheckpoint dirpath expanding home prefix (#19058) --- src/lightning/pytorch/CHANGELOG.md | 4 ++++ .../pytorch/callbacks/model_checkpoint.py | 2 +- .../checkpointing/test_model_checkpoint.py | 14 ++++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 2a88d29f5b..fb94ef6c42 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -53,6 +53,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue where setting `CHECKPOINT_JOIN_CHAR` or `CHECKPOINT_EQUALS_CHAR` would only work on the `ModelCheckpoint` class but not on an instance ([#19054](https://github.com/Lightning-AI/lightning/pull/19054)) +- Fixed `ModelCheckpoint` not expanding the `dirpath` if it has the `~` (home) prefix ([#19058](https://github.com/Lightning-AI/lightning/pull/19058)) + + + ## [2.1.2] - 2023-11-15 ### Fixed diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 565aefaf3a..1d3ec47e87 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -467,7 +467,7 @@ class ModelCheckpoint(Checkpoint): self._fs = get_filesystem(dirpath if dirpath else "") if dirpath and _is_local_file_protocol(dirpath if dirpath else ""): - dirpath = os.path.realpath(dirpath) + dirpath = os.path.realpath(os.path.expanduser(dirpath)) self.dirpath = dirpath self.filename = filename diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 81cc98cf50..66764c7830 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -1536,3 +1536,17 @@ def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_p callback.FILE_EXTENSION = extension files = callback._find_last_checkpoints(trainer) assert files == {str(tmp_path / p) for p in expected} + + +def test_expand_home(): + """Test that the dirpath gets expanded if it contains `~`.""" + home_root = Path.home() + + checkpoint = ModelCheckpoint(dirpath="~/checkpoints") + assert checkpoint.dirpath == str(home_root / "checkpoints") + checkpoint = ModelCheckpoint(dirpath=Path("~/checkpoints")) + assert checkpoint.dirpath == str(home_root / "checkpoints") + + # it is possible to have a folder with the name `~` + checkpoint = ModelCheckpoint(dirpath="./~/checkpoints") + assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints")