Removed duplicated file extension when uploading model checkpoints with NeptuneLogger (#11015)

This commit is contained in:
Rafał Jankowski 2021-12-10 19:33:12 +01:00 committed by GitHub
parent 5576fbc5f9
commit ed84cef3af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 3 deletions

View File

@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))
### Deprecated
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))

View File

@ -553,7 +553,10 @@ class NeptuneLogger(LightningLoggerBase):
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
if not model_path.startswith(expected_model_path):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
return model_path[len(expected_model_path) :]
# Remove extension from filepath
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
return filepath
@classmethod
def _get_full_model_names_from_exp_structure(cls, exp_structure: dict, namespace: str) -> Set[str]:

View File

@ -397,9 +397,9 @@ class TestNeptuneLoggerUtils(unittest.TestCase):
# given:
SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
test_input_data = [
("key.ext", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
(
"key/in/parts.ext",
"key/in/parts",
os.path.join("foo", "bar", "key/in/parts.ext"),
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
),