add version_ prefix to log_dir (#706)
* add version_ prefix to log_dir * add version_ prefix
This commit is contained in:
parent
53b7644c15
commit
de2ccc03a8
|
@ -63,7 +63,7 @@ class TensorBoardLogger(LightningLoggerBase):
|
||||||
|
|
||||||
root_dir = os.path.join(self.save_dir, self.name)
|
root_dir = os.path.join(self.save_dir, self.name)
|
||||||
os.makedirs(root_dir, exist_ok=True)
|
os.makedirs(root_dir, exist_ok=True)
|
||||||
log_dir = os.path.join(root_dir, str(self.version))
|
log_dir = os.path.join(root_dir, "version_" + str(self.version))
|
||||||
self._experiment = SummaryWriter(log_dir=log_dir, **self.kwargs)
|
self._experiment = SummaryWriter(log_dir=log_dir, **self.kwargs)
|
||||||
return self._experiment
|
return self._experiment
|
||||||
|
|
||||||
|
@ -131,9 +131,11 @@ class TensorBoardLogger(LightningLoggerBase):
|
||||||
|
|
||||||
def _get_next_version(self):
|
def _get_next_version(self):
|
||||||
root_dir = os.path.join(self.save_dir, self.name)
|
root_dir = os.path.join(self.save_dir, self.name)
|
||||||
existing_versions = [
|
existing_versions = []
|
||||||
int(d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d)) and d.isdigit()
|
for d in os.listdir(root_dir):
|
||||||
]
|
if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
|
||||||
|
existing_versions.append(int(d.split("_")[1]))
|
||||||
|
|
||||||
if len(existing_versions) == 0:
|
if len(existing_versions) == 0:
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -296,8 +296,8 @@ def test_tensorboard_automatic_versioning(tmpdir):
|
||||||
"""Verify that automatic versioning works"""
|
"""Verify that automatic versioning works"""
|
||||||
|
|
||||||
root_dir = tmpdir.mkdir("tb_versioning")
|
root_dir = tmpdir.mkdir("tb_versioning")
|
||||||
root_dir.mkdir("0")
|
root_dir.mkdir("version_0")
|
||||||
root_dir.mkdir("1")
|
root_dir.mkdir("version_1")
|
||||||
|
|
||||||
logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning")
|
logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning")
|
||||||
|
|
||||||
|
@ -308,9 +308,9 @@ def test_tensorboard_manual_versioning(tmpdir):
|
||||||
"""Verify that manual versioning works"""
|
"""Verify that manual versioning works"""
|
||||||
|
|
||||||
root_dir = tmpdir.mkdir("tb_versioning")
|
root_dir = tmpdir.mkdir("tb_versioning")
|
||||||
root_dir.mkdir("0")
|
root_dir.mkdir("version_0")
|
||||||
root_dir.mkdir("1")
|
root_dir.mkdir("version_1")
|
||||||
root_dir.mkdir("2")
|
root_dir.mkdir("version_2")
|
||||||
|
|
||||||
logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=1)
|
logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue