diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 94268f6063..b69f31ae53 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -267,14 +267,16 @@ class TensorBoardLogger(LightningLoggerBase): return self._version def _get_next_version(self): - root_dir = os.path.join(self.save_dir, self.name) + root_dir = self.root_dir - if not self._fs.isdir(root_dir): + try: + listdir_info = self._fs.listdir(root_dir) + except OSError: log.warning('Missing logger folder: %s', root_dir) return 0 existing_versions = [] - for listing in self._fs.listdir(root_dir): + for listing in listdir_info: d = listing["name"] bn = os.path.basename(d) if self._fs.isdir(d) and bn.startswith("version_"): diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index f7fe1c3bfd..ffd89a0c14 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from argparse import Namespace from unittest import mock @@ -340,3 +341,15 @@ def test_tensorboard_with_symlink(log, tmpdir): _ = logger.version log.warning.assert_not_called() + + +def test_tensorboard_missing_folder_warning(tmpdir, caplog): + """Verify that the logger throws a warning for invalid directory""" + + name = "fake_dir" + logger = TensorBoardLogger(save_dir=tmpdir, name=name) + + with caplog.at_level(logging.WARNING): + assert logger.version == 0 + + assert 'Missing logger folder:' in caplog.text