diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 274e647bd4..30d525f90c 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed warning for Dataloader if `num_workers=1` and CPU count is 1 ([#19224](https://github.com/Lightning-AI/lightning/pull/19224)) +- Fixed an issue with the ModelCheckpoint callback not saving relative symlinks with `ModelCheckpoint(save_last="link")` ([#19303](https://github.com/Lightning-AI/lightning/pull/19303)) + + ## [2.1.3] - 2023-12-21 ### Changed diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 0cb4adb82f..e1e823fc92 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -403,7 +403,7 @@ class ModelCheckpoint(Checkpoint): elif os.path.isdir(linkpath): shutil.rmtree(linkpath) try: - os.symlink(filepath, linkpath) + os.symlink(os.path.relpath(filepath, os.path.dirname(linkpath)), linkpath) except OSError: # on Windows, special permissions are required to create symbolic links as a regular user # fall back to copying the file diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 65d625e604..fb2c8d8e35 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -534,6 +534,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path): ModelCheckpoint._link_checkpoint(trainer, filepath=str(file), linkpath=str(link)) assert os.path.islink(link) assert os.path.realpath(link) == str(file) + assert not os.path.isabs(os.readlink(link)) # link exists (is a file) new_file1 = tmp_path / "new_file1" @@ -541,6 +542,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path): ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file1), linkpath=str(link)) assert os.path.islink(link) assert os.path.realpath(link) == str(new_file1) + assert not os.path.isabs(os.readlink(link)) # link exists (is a link) new_file2 = tmp_path / "new_file2" @@ -548,6 +550,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path): ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file2), linkpath=str(link)) assert os.path.islink(link) assert os.path.realpath(link) == str(new_file2) + assert not os.path.isabs(os.readlink(link)) # link exists (is a folder) folder = tmp_path / "folder" @@ -557,6 +560,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path): ModelCheckpoint._link_checkpoint(trainer, filepath=str(folder), linkpath=str(folder_link)) assert os.path.islink(folder_link) assert os.path.realpath(folder_link) == str(folder) + assert not os.path.isabs(os.readlink(folder_link)) # link exists (is a link to a folder) new_folder = tmp_path / "new_folder" @@ -564,6 +568,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path): ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_folder), linkpath=str(folder_link)) assert os.path.islink(folder_link) assert os.path.realpath(folder_link) == str(new_folder) + assert not os.path.isabs(os.readlink(folder_link)) # simulate permission error on Windows (creation of symbolic links requires privileges) file = tmp_path / "win_file" @@ -575,6 +580,22 @@ def test_model_checkpoint_link_checkpoint(tmp_path): assert os.path.isfile(link) # fall back to copying instead of linking +def test_model_checkpoint_link_checkpoint_relative_path(tmp_path, monkeypatch): + """Test that linking a checkpoint works with relative paths.""" + trainer = Mock() + monkeypatch.chdir(tmp_path) + + folder = Path("x/z/z") + folder.mkdir(parents=True) + file = folder / "file" + file.touch() + link = folder / "link" + ModelCheckpoint._link_checkpoint(trainer, filepath=str(file.absolute()), linkpath=str(link.absolute())) + assert os.path.islink(link) + assert Path(os.readlink(link)) == file.relative_to(folder) + assert not os.path.isabs(os.readlink(link)) + + def test_invalid_top_k(tmpdir): """Make sure that a MisconfigurationException is raised for a negative save_top_k argument.""" with pytest.raises(MisconfigurationException, match=r".*Must be >= -1"):