diff --git a/CHANGELOG.md b/CHANGELOG.md index d1fc665048..d1a2e19679 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8) +- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015)) + + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 569223754e..398b93e06c 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -553,7 +553,10 @@ class NeptuneLogger(LightningLoggerBase): expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}" if not model_path.startswith(expected_model_path): raise ValueError(f"{model_path} was expected to start with {expected_model_path}.") - return model_path[len(expected_model_path) :] + # Remove extension from filepath + filepath, _ = os.path.splitext(model_path[len(expected_model_path) :]) + + return filepath @classmethod def _get_full_model_names_from_exp_structure(cls, exp_structure: dict, namespace: str) -> Set[str]: diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py index 6238b408c1..cb7ef9c515 100644 --- a/tests/loggers/test_neptune.py +++ b/tests/loggers/test_neptune.py @@ -397,9 +397,9 @@ class TestNeptuneLoggerUtils(unittest.TestCase): # given: SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"]) test_input_data = [ - ("key.ext", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))), + ("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))), ( - "key/in/parts.ext", + "key/in/parts", os.path.join("foo", "bar", "key/in/parts.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar")), ),