Fix saving relative symlink for ModelCheckpoint callback (#19303)

Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
shenmishajing 2024-01-20 22:32:08 +08:00 committed by GitHub
parent e89f46a74e
commit d02009af76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 1 deletions

View File

@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed warning for Dataloader if `num_workers=1` and CPU count is 1 ([#19224](https://github.com/Lightning-AI/lightning/pull/19224))
- Fixed an issue with the ModelCheckpoint callback not saving relative symlinks with `ModelCheckpoint(save_last="link")` ([#19303](https://github.com/Lightning-AI/lightning/pull/19303))
## [2.1.3] - 2023-12-21
### Changed

View File

@ -403,7 +403,7 @@ class ModelCheckpoint(Checkpoint):
elif os.path.isdir(linkpath):
shutil.rmtree(linkpath)
try:
os.symlink(filepath, linkpath)
os.symlink(os.path.relpath(filepath, os.path.dirname(linkpath)), linkpath)
except OSError:
# on Windows, special permissions are required to create symbolic links as a regular user
# fall back to copying the file

View File

@ -534,6 +534,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file), linkpath=str(link))
assert os.path.islink(link)
assert os.path.realpath(link) == str(file)
assert not os.path.isabs(os.readlink(link))
# link exists (is a file)
new_file1 = tmp_path / "new_file1"
@ -541,6 +542,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file1), linkpath=str(link))
assert os.path.islink(link)
assert os.path.realpath(link) == str(new_file1)
assert not os.path.isabs(os.readlink(link))
# link exists (is a link)
new_file2 = tmp_path / "new_file2"
@ -548,6 +550,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file2), linkpath=str(link))
assert os.path.islink(link)
assert os.path.realpath(link) == str(new_file2)
assert not os.path.isabs(os.readlink(link))
# link exists (is a folder)
folder = tmp_path / "folder"
@ -557,6 +560,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
ModelCheckpoint._link_checkpoint(trainer, filepath=str(folder), linkpath=str(folder_link))
assert os.path.islink(folder_link)
assert os.path.realpath(folder_link) == str(folder)
assert not os.path.isabs(os.readlink(folder_link))
# link exists (is a link to a folder)
new_folder = tmp_path / "new_folder"
@ -564,6 +568,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_folder), linkpath=str(folder_link))
assert os.path.islink(folder_link)
assert os.path.realpath(folder_link) == str(new_folder)
assert not os.path.isabs(os.readlink(folder_link))
# simulate permission error on Windows (creation of symbolic links requires privileges)
file = tmp_path / "win_file"
@ -575,6 +580,22 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
assert os.path.isfile(link) # fall back to copying instead of linking
def test_model_checkpoint_link_checkpoint_relative_path(tmp_path, monkeypatch):
    """Check that ``_link_checkpoint`` stores a relative target when called with absolute paths."""
    # Run from inside ``tmp_path`` so the relative ``Path`` objects below are created under it.
    monkeypatch.chdir(tmp_path)

    ckpt_dir = Path("x/z/z")
    ckpt_dir.mkdir(parents=True)
    target = ckpt_dir / "file"
    target.touch()
    symlink = ckpt_dir / "link"

    # Both paths are handed over in absolute form; the link written to disk must still be relative.
    ModelCheckpoint._link_checkpoint(
        Mock(),
        filepath=str(target.absolute()),
        linkpath=str(symlink.absolute()),
    )

    assert os.path.islink(symlink)
    stored_target = os.readlink(symlink)
    # The stored target is expressed relative to the directory that contains the link.
    assert Path(stored_target) == target.relative_to(ckpt_dir)
    assert not os.path.isabs(stored_target)
def test_invalid_top_k(tmpdir):
"""Make sure that a MisconfigurationException is raised for a negative save_top_k argument."""
with pytest.raises(MisconfigurationException, match=r".*Must be >= -1"):