add version_ prefix to log_dir (#706)

* add version_ prefix to log_dir

* add version_ prefix
This commit is contained in:
Z ZH 2020-01-18 21:17:54 +09:00 committed by William Falcon
parent 53b7644c15
commit de2ccc03a8
2 changed files with 11 additions and 9 deletions

View File

@ -63,7 +63,7 @@ class TensorBoardLogger(LightningLoggerBase):
root_dir = os.path.join(self.save_dir, self.name) root_dir = os.path.join(self.save_dir, self.name)
os.makedirs(root_dir, exist_ok=True) os.makedirs(root_dir, exist_ok=True)
log_dir = os.path.join(root_dir, str(self.version)) log_dir = os.path.join(root_dir, "version_" + str(self.version))
self._experiment = SummaryWriter(log_dir=log_dir, **self.kwargs) self._experiment = SummaryWriter(log_dir=log_dir, **self.kwargs)
return self._experiment return self._experiment
@ -131,9 +131,11 @@ class TensorBoardLogger(LightningLoggerBase):
def _get_next_version(self): def _get_next_version(self):
root_dir = os.path.join(self.save_dir, self.name) root_dir = os.path.join(self.save_dir, self.name)
existing_versions = [ existing_versions = []
int(d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d)) and d.isdigit() for d in os.listdir(root_dir):
] if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
existing_versions.append(int(d.split("_")[1]))
if len(existing_versions) == 0: if len(existing_versions) == 0:
return 0 return 0
else: else:

View File

@ -296,8 +296,8 @@ def test_tensorboard_automatic_versioning(tmpdir):
"""Verify that automatic versioning works""" """Verify that automatic versioning works"""
root_dir = tmpdir.mkdir("tb_versioning") root_dir = tmpdir.mkdir("tb_versioning")
root_dir.mkdir("0") root_dir.mkdir("version_0")
root_dir.mkdir("1") root_dir.mkdir("version_1")
logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning") logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning")
@ -308,9 +308,9 @@ def test_tensorboard_manual_versioning(tmpdir):
"""Verify that manual versioning works""" """Verify that manual versioning works"""
root_dir = tmpdir.mkdir("tb_versioning") root_dir = tmpdir.mkdir("tb_versioning")
root_dir.mkdir("0") root_dir.mkdir("version_0")
root_dir.mkdir("1") root_dir.mkdir("version_1")
root_dir.mkdir("2") root_dir.mkdir("version_2")
logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=1) logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=1)