Handle checkpoint dirpath suffix in NeptuneLogger (#18863)
Co-authored-by: Siddhant Sadangi <siddhant.sadangi@gmail.com> Co-authored-by: Sabine <sabine.nyholm@neptune.ai> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
1fcb4ae637
commit
af852ff590
|
@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed `ModelCheckpoint` not expanding the `dirpath` if it has the `~` (home) prefix ([#19058](https://github.com/Lightning-AI/lightning/pull/19058))
|
||||
|
||||
|
||||
- Fixed handling checkpoint dirpath suffix in NeptuneLogger ([#18863](https://github.com/Lightning-AI/lightning/pull/18863))
|
||||
|
||||
|
||||
|
||||
## [2.1.2] - 2023-11-15
|
||||
|
||||
|
|
|
@ -557,13 +557,14 @@ class NeptuneLogger(Logger):
|
|||
def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str:
|
||||
"""Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
|
||||
if hasattr(checkpoint_callback, "dirpath"):
|
||||
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
|
||||
model_path = os.path.normpath(model_path)
|
||||
expected_model_path = os.path.normpath(checkpoint_callback.dirpath)
|
||||
if not model_path.startswith(expected_model_path):
|
||||
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
|
||||
# Remove extension from filepath
|
||||
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
|
||||
return filepath
|
||||
return model_path
|
||||
filepath, _ = os.path.splitext(model_path[len(expected_model_path) + 1 :])
|
||||
return filepath.replace(os.sep, "/")
|
||||
return model_path.replace(os.sep, "/")
|
||||
|
||||
@classmethod
|
||||
def _get_full_model_names_from_exp_structure(cls, exp_structure: Dict[str, Any], namespace: str) -> Set[str]:
|
||||
|
|
|
@ -284,10 +284,12 @@ def test_get_full_model_name():
|
|||
os.path.join("foo", "bar", "key/in/parts.ext"),
|
||||
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
|
||||
),
|
||||
("key", os.path.join("../foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("../foo", "bar"))),
|
||||
("key", os.path.join("foo", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../"))),
|
||||
]
|
||||
|
||||
for expected_model_name, *key_and_path in test_input_data:
|
||||
assert NeptuneLogger._get_full_model_name(*key_and_path) == expected_model_name
|
||||
for expected_model_name, model_path, checkpoint in test_input_data:
|
||||
assert NeptuneLogger._get_full_model_name(model_path, checkpoint) == expected_model_name
|
||||
|
||||
|
||||
def test_get_full_model_names_from_exp_structure():
|
||||
|
|
Loading…
Reference in New Issue