Handle checkpoint dirpath suffix in NeptuneLogger (#18863)

Co-authored-by: Siddhant Sadangi <siddhant.sadangi@gmail.com>
Co-authored-by: Sabine <sabine.nyholm@neptune.ai>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
AleksanderWWW 2023-11-25 14:39:46 +01:00 committed by GitHub
parent 1fcb4ae637
commit af852ff590
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 6 deletions

View File

@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `ModelCheckpoint` not expanding the `dirpath` if it has the `~` (home) prefix ([#19058](https://github.com/Lightning-AI/lightning/pull/19058))
- Fixed handling checkpoint dirpath suffix in NeptuneLogger ([#18863](https://github.com/Lightning-AI/lightning/pull/18863))
## [2.1.2] - 2023-11-15

View File

@ -557,13 +557,14 @@ class NeptuneLogger(Logger):
def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str:
"""Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
if hasattr(checkpoint_callback, "dirpath"):
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
model_path = os.path.normpath(model_path)
expected_model_path = os.path.normpath(checkpoint_callback.dirpath)
if not model_path.startswith(expected_model_path):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
# Remove extension from filepath
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
return filepath
return model_path
filepath, _ = os.path.splitext(model_path[len(expected_model_path) + 1 :])
return filepath.replace(os.sep, "/")
return model_path.replace(os.sep, "/")
@classmethod
def _get_full_model_names_from_exp_structure(cls, exp_structure: Dict[str, Any], namespace: str) -> Set[str]:

View File

@ -284,10 +284,12 @@ def test_get_full_model_name():
os.path.join("foo", "bar", "key/in/parts.ext"),
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
),
("key", os.path.join("../foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("../foo", "bar"))),
("key", os.path.join("foo", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../"))),
]
for expected_model_name, *key_and_path in test_input_data:
assert NeptuneLogger._get_full_model_name(*key_and_path) == expected_model_name
for expected_model_name, model_path, checkpoint in test_input_data:
assert NeptuneLogger._get_full_model_name(model_path, checkpoint) == expected_model_name
def test_get_full_model_names_from_exp_structure():