Fix logs overwriting issue for remote fs (#7889)
* Fix logs overwriting issue for remote fs * Add test
This commit is contained in:
parent
c310ce661e
commit
7f4ef6d135
|
@ -267,14 +267,16 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
return self._version
|
||||
|
||||
def _get_next_version(self):
|
||||
root_dir = os.path.join(self.save_dir, self.name)
|
||||
root_dir = self.root_dir
|
||||
|
||||
if not self._fs.isdir(root_dir):
|
||||
try:
|
||||
listdir_info = self._fs.listdir(root_dir)
|
||||
except OSError:
|
||||
log.warning('Missing logger folder: %s', root_dir)
|
||||
return 0
|
||||
|
||||
existing_versions = []
|
||||
for listing in self._fs.listdir(root_dir):
|
||||
for listing in listdir_info:
|
||||
d = listing["name"]
|
||||
bn = os.path.basename(d)
|
||||
if self._fs.isdir(d) and bn.startswith("version_"):
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from unittest import mock
|
||||
|
@ -340,3 +341,15 @@ def test_tensorboard_with_symlink(log, tmpdir):
|
|||
_ = logger.version
|
||||
|
||||
log.warning.assert_not_called()
|
||||
|
||||
|
||||
def test_tensorboard_missing_folder_warning(tmpdir, caplog):
|
||||
"""Verify that the logger throws a warning for invalid directory"""
|
||||
|
||||
name = "fake_dir"
|
||||
logger = TensorBoardLogger(save_dir=tmpdir, name=name)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
assert logger.version == 0
|
||||
|
||||
assert 'Missing logger folder:' in caplog.text
|
||||
|
|
Loading…
Reference in New Issue