From 7f4ef6d135b8b824d23be9fe2ab21f184b12ea9b Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Wed, 9 Jun 2021 14:35:01 +0530 Subject: [PATCH] Fix logs overwriting issue for remote fs (#7889) * Fix logs overwriting issue for remote fs * Add test --- pytorch_lightning/loggers/tensorboard.py | 8 +++++--- tests/loggers/test_tensorboard.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 94268f6063..b69f31ae53 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -267,14 +267,16 @@ class TensorBoardLogger(LightningLoggerBase): return self._version def _get_next_version(self): - root_dir = os.path.join(self.save_dir, self.name) + root_dir = self.root_dir - if not self._fs.isdir(root_dir): + try: + listdir_info = self._fs.listdir(root_dir) + except OSError: log.warning('Missing logger folder: %s', root_dir) return 0 existing_versions = [] - for listing in self._fs.listdir(root_dir): + for listing in listdir_info: d = listing["name"] bn = os.path.basename(d) if self._fs.isdir(d) and bn.startswith("version_"): diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index f7fe1c3bfd..ffd89a0c14 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from argparse import Namespace from unittest import mock @@ -340,3 +341,15 @@ def test_tensorboard_with_symlink(log, tmpdir): _ = logger.version log.warning.assert_not_called() + + +def test_tensorboard_missing_folder_warning(tmpdir, caplog): + """Verify that the logger throws a warning for invalid directory""" + + name = "fake_dir" + logger = TensorBoardLogger(save_dir=tmpdir, name=name) + + with caplog.at_level(logging.WARNING): + assert logger.version == 0 + + assert 'Missing logger folder:' in caplog.text