From de2ccc03a8df997b8841f33ae70050498960f08c Mon Sep 17 00:00:00 2001 From: Z ZH Date: Sat, 18 Jan 2020 21:17:54 +0900 Subject: [PATCH] add version_ prefix to log_dir (#706) * add version_ prefix to log_dir * add version_ prefix --- pytorch_lightning/logging/tensorboard.py | 10 ++++++---- tests/test_logging.py | 10 +++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/logging/tensorboard.py b/pytorch_lightning/logging/tensorboard.py index 937fc3e6fd..73862a0755 100644 --- a/pytorch_lightning/logging/tensorboard.py +++ b/pytorch_lightning/logging/tensorboard.py @@ -63,7 +63,7 @@ class TensorBoardLogger(LightningLoggerBase): root_dir = os.path.join(self.save_dir, self.name) os.makedirs(root_dir, exist_ok=True) - log_dir = os.path.join(root_dir, str(self.version)) + log_dir = os.path.join(root_dir, "version_" + str(self.version)) self._experiment = SummaryWriter(log_dir=log_dir, **self.kwargs) return self._experiment @@ -131,9 +131,11 @@ class TensorBoardLogger(LightningLoggerBase): def _get_next_version(self): root_dir = os.path.join(self.save_dir, self.name) - existing_versions = [ - int(d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d)) and d.isdigit() - ] + existing_versions = [] + for d in os.listdir(root_dir): + if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): + existing_versions.append(int(d.split("_")[1])) + if len(existing_versions) == 0: return 0 else: diff --git a/tests/test_logging.py b/tests/test_logging.py index 5467d0aab3..a91f6087d9 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -296,8 +296,8 @@ def test_tensorboard_automatic_versioning(tmpdir): """Verify that automatic versioning works""" root_dir = tmpdir.mkdir("tb_versioning") - root_dir.mkdir("0") - root_dir.mkdir("1") + root_dir.mkdir("version_0") + root_dir.mkdir("version_1") logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning") @@ -308,9 +308,9 @@ def test_tensorboard_manual_versioning(tmpdir): """Verify that manual versioning works""" root_dir = tmpdir.mkdir("tb_versioning") - root_dir.mkdir("0") - root_dir.mkdir("1") - root_dir.mkdir("2") + root_dir.mkdir("version_0") + root_dir.mkdir("version_1") + root_dir.mkdir("version_2") logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=1)