Removed duplicated file extension when uploading model checkpoints with NeptuneLogger (#11015)
This commit is contained in:
parent
5576fbc5f9
commit
ed84cef3af
|
@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
|
||||
|
||||
|
||||
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
|
||||
|
|
|
@ -553,7 +553,10 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
|
||||
if not model_path.startswith(expected_model_path):
|
||||
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
|
||||
return model_path[len(expected_model_path) :]
|
||||
# Remove extension from filepath
|
||||
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
|
||||
|
||||
return filepath
|
||||
|
||||
@classmethod
|
||||
def _get_full_model_names_from_exp_structure(cls, exp_structure: dict, namespace: str) -> Set[str]:
|
||||
|
|
|
@ -397,9 +397,9 @@ class TestNeptuneLoggerUtils(unittest.TestCase):
|
|||
# given:
|
||||
SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
|
||||
test_input_data = [
|
||||
("key.ext", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
|
||||
("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
|
||||
(
|
||||
"key/in/parts.ext",
|
||||
"key/in/parts",
|
||||
os.path.join("foo", "bar", "key/in/parts.ext"),
|
||||
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
|
||||
),
|
||||
|
|
Loading…
Reference in New Issue