Fix saving relative symlink for ModelCheckpoint callback (#19303)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
e89f46a74e
commit
d02009af76
|
@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed warning for Dataloader if `num_workers=1` and CPU count is 1 ([#19224](https://github.com/Lightning-AI/lightning/pull/19224))
|
||||
|
||||
|
||||
- Fixed an issue with the ModelCheckpoint callback not saving relative symlinks with `ModelCheckpoint(save_last="link")` ([#19303](https://github.com/Lightning-AI/lightning/pull/19303))
|
||||
|
||||
|
||||
## [2.1.3] - 2023-12-21
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -403,7 +403,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
elif os.path.isdir(linkpath):
|
||||
shutil.rmtree(linkpath)
|
||||
try:
|
||||
os.symlink(filepath, linkpath)
|
||||
os.symlink(os.path.relpath(filepath, os.path.dirname(linkpath)), linkpath)
|
||||
except OSError:
|
||||
# on Windows, special permissions are required to create symbolic links as a regular user
|
||||
# fall back to copying the file
|
||||
|
|
|
@ -534,6 +534,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
|
|||
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file), linkpath=str(link))
|
||||
assert os.path.islink(link)
|
||||
assert os.path.realpath(link) == str(file)
|
||||
assert not os.path.isabs(os.readlink(link))
|
||||
|
||||
# link exists (is a file)
|
||||
new_file1 = tmp_path / "new_file1"
|
||||
|
@ -541,6 +542,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
|
|||
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file1), linkpath=str(link))
|
||||
assert os.path.islink(link)
|
||||
assert os.path.realpath(link) == str(new_file1)
|
||||
assert not os.path.isabs(os.readlink(link))
|
||||
|
||||
# link exists (is a link)
|
||||
new_file2 = tmp_path / "new_file2"
|
||||
|
@ -548,6 +550,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
|
|||
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_file2), linkpath=str(link))
|
||||
assert os.path.islink(link)
|
||||
assert os.path.realpath(link) == str(new_file2)
|
||||
assert not os.path.isabs(os.readlink(link))
|
||||
|
||||
# link exists (is a folder)
|
||||
folder = tmp_path / "folder"
|
||||
|
@ -557,6 +560,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
|
|||
ModelCheckpoint._link_checkpoint(trainer, filepath=str(folder), linkpath=str(folder_link))
|
||||
assert os.path.islink(folder_link)
|
||||
assert os.path.realpath(folder_link) == str(folder)
|
||||
assert not os.path.isabs(os.readlink(folder_link))
|
||||
|
||||
# link exists (is a link to a folder)
|
||||
new_folder = tmp_path / "new_folder"
|
||||
|
@ -564,6 +568,7 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
|
|||
ModelCheckpoint._link_checkpoint(trainer, filepath=str(new_folder), linkpath=str(folder_link))
|
||||
assert os.path.islink(folder_link)
|
||||
assert os.path.realpath(folder_link) == str(new_folder)
|
||||
assert not os.path.isabs(os.readlink(folder_link))
|
||||
|
||||
# simulate permission error on Windows (creation of symbolic links requires privileges)
|
||||
file = tmp_path / "win_file"
|
||||
|
@ -575,6 +580,22 @@ def test_model_checkpoint_link_checkpoint(tmp_path):
|
|||
assert os.path.isfile(link) # fall back to copying instead of linking
|
||||
|
||||
|
||||
def test_model_checkpoint_link_checkpoint_relative_path(tmp_path, monkeypatch):
|
||||
"""Test that linking a checkpoint works with relative paths."""
|
||||
trainer = Mock()
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
folder = Path("x/z/z")
|
||||
folder.mkdir(parents=True)
|
||||
file = folder / "file"
|
||||
file.touch()
|
||||
link = folder / "link"
|
||||
ModelCheckpoint._link_checkpoint(trainer, filepath=str(file.absolute()), linkpath=str(link.absolute()))
|
||||
assert os.path.islink(link)
|
||||
assert Path(os.readlink(link)) == file.relative_to(folder)
|
||||
assert not os.path.isabs(os.readlink(link))
|
||||
|
||||
|
||||
def test_invalid_top_k(tmpdir):
|
||||
"""Make sure that a MisconfigurationException is raised for a negative save_top_k argument."""
|
||||
with pytest.raises(MisconfigurationException, match=r".*Must be >= -1"):
|
||||
|
|
Loading…
Reference in New Issue