Fix logs overwriting issue for remote fs (#7889)

* Fix logs overwriting issue for remote fs

* Add test
This commit is contained in:
Kaushik B 2021-06-09 14:35:01 +05:30 committed by GitHub
parent c310ce661e
commit 7f4ef6d135
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 3 deletions

View File

@ -267,14 +267,16 @@ class TensorBoardLogger(LightningLoggerBase):
return self._version
def _get_next_version(self):
root_dir = os.path.join(self.save_dir, self.name)
root_dir = self.root_dir
if not self._fs.isdir(root_dir):
try:
listdir_info = self._fs.listdir(root_dir)
except OSError:
log.warning('Missing logger folder: %s', root_dir)
return 0
existing_versions = []
for listing in self._fs.listdir(root_dir):
for listing in listdir_info:
d = listing["name"]
bn = os.path.basename(d)
if self._fs.isdir(d) and bn.startswith("version_"):

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from argparse import Namespace
from unittest import mock
@ -340,3 +341,15 @@ def test_tensorboard_with_symlink(log, tmpdir):
_ = logger.version
log.warning.assert_not_called()
def test_tensorboard_missing_folder_warning(tmpdir, caplog):
"""Verify that the logger throws a warning for invalid directory"""
name = "fake_dir"
logger = TensorBoardLogger(save_dir=tmpdir, name=name)
with caplog.at_level(logging.WARNING):
assert logger.version == 0
assert 'Missing logger folder:' in caplog.text