# tests_fabric/loggers/test_tensorboard.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from argparse import Namespace
from unittest import mock
from unittest.mock import Mock

import numpy as np
import pytest
import torch
from tests_fabric.test_fabric import BoringModel

from lightning_fabric.loggers import TensorBoardLogger
from lightning_fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE


def test_tensorboard_automatic_versioning(tmpdir):
"""Verify that automatic versioning works."""
root_dir = tmpdir / "tb_versioning"
root_dir.mkdir()
(root_dir / "version_0").mkdir()
(root_dir / "version_1").mkdir()
logger = TensorBoardLogger(root_dir=tmpdir, name="tb_versioning")
assert logger.version == 2
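

# A minimal sketch of how the automatic version above is presumably derived:
# scan the existing ``version_*`` folders and take max + 1. ``_infer_next_version``
# is a hypothetical helper for illustration, not the logger's actual API.
def _infer_next_version(root_dir):
    existing = [
        int(d.split("_")[1])
        for d in os.listdir(root_dir)
        if d.startswith("version_") and d.split("_")[1].isdigit()
    ]
    return max(existing, default=-1) + 1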


def test_tensorboard_manual_versioning(tmpdir):
"""Verify that manual versioning works."""
root_dir = tmpdir / "tb_versioning"
root_dir.mkdir()
(root_dir / "version_0").mkdir()
(root_dir / "version_1").mkdir()
(root_dir / "version_2").mkdir()
logger = TensorBoardLogger(root_dir=tmpdir, name="tb_versioning", version=1)
assert logger.version == 1


def test_tensorboard_named_version(tmpdir):
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
name = "tb_versioning"
(tmpdir / name).mkdir()
expected_version = "2020-02-05-162402"
logger = TensorBoardLogger(root_dir=tmpdir, name=name, version=expected_version)
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written
assert logger.version == expected_version
assert os.listdir(tmpdir / name) == [expected_version]
assert os.listdir(tmpdir / name / expected_version)


@pytest.mark.parametrize("name", ["", None])
def test_tensorboard_no_name(tmpdir, name):
"""Verify that None or empty name works."""
logger = TensorBoardLogger(root_dir=tmpdir, name=name)
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written
assert os.path.normpath(logger.root_dir) == tmpdir # use os.path.normpath to handle trailing /
assert os.listdir(tmpdir / "version_0")


def test_tensorboard_log_sub_dir(tmpdir):
# no sub_dir specified
root_dir = tmpdir / "logs"
logger = TensorBoardLogger(root_dir, name="name", version="version")
assert logger.log_dir == os.path.join(root_dir, "name", "version")

    # sub_dir specified
logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir")
assert logger.log_dir == os.path.join(root_dir, "name", "version", "sub_dir")
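

# For reference, the layout asserted above (an assumption read off these tests,
# not a documented guarantee) is ``root_dir/name/version[/sub_dir]``, with event
# files written to the innermost folder, e.g.
#   logs/name/version/sub_dir/events.out.tfevents.*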


def test_tensorboard_expand_home():
"""Test that the home dir (`~`) gets expanded properly."""
root_dir = "~/tmp"
explicit_root_dir = os.path.expanduser(root_dir)
logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir")
assert logger.root_dir == root_dir
assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir")


@mock.patch.dict(os.environ, {"TEST_ENV_DIR": "some_directory"})
def test_tensorboard_expand_env_vars():
"""Test that the env vars in path names (`$`) get handled properly."""
test_env_dir = os.environ["TEST_ENV_DIR"]
root_dir = "$TEST_ENV_DIR/tmp"
explicit_root_dir = f"{test_env_dir}/tmp"
logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir")
assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir")
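

# The two tests above exercise path expansion in the root dir. The stdlib
# helpers that presumably back this behavior (an assumption, not confirmed
# against the logger's source here):
#   os.path.expanduser("~/tmp")              # -> "/home/<user>/tmp"
#   os.path.expandvars("$TEST_ENV_DIR/tmp")  # -> "some_directory/tmp" (when the env var is set)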


@pytest.mark.parametrize("step_idx", [10, None])
def test_tensorboard_log_metrics(tmpdir, step_idx):
logger = TensorBoardLogger(tmpdir)
metrics = {"float": 0.3, "int": 1, "FloatTensor": torch.tensor(0.1), "IntTensor": torch.tensor(1)}
logger.log_metrics(metrics, step_idx)


def test_tensorboard_log_hyperparams(tmpdir):
logger = TensorBoardLogger(tmpdir)
hparams = {
"float": 0.3,
"int": 1,
"string": "abc",
"bool": True,
"dict": {"a": {"b": "c"}},
"list": [1, 2, 3],
"namespace": Namespace(foo=Namespace(bar="buzz")),
"layer": torch.nn.BatchNorm1d,
"tensor": torch.empty(2, 2, 2),
"array": np.empty([2, 2, 2]),
}
logger.log_hyperparams(hparams)
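

# Nested containers such as the ``dict`` and ``namespace`` entries above are
# typically flattened to dotted keys before reaching TensorBoard's hparams
# plugin. A minimal sketch of that idea (``_flatten`` is a hypothetical
# illustration, not the logger's implementation):
def _flatten(params, parent_key=""):
    items = {}
    for k, v in params.items():
        key = f"{parent_key}.{k}" if parent_key else str(k)
        if isinstance(v, dict):
            items.update(_flatten(v, key))
        else:
            items[key] = v
    return items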


def test_tensorboard_log_hparams_and_metrics(tmpdir):
logger = TensorBoardLogger(tmpdir, default_hp_metric=False)
hparams = {
"float": 0.3,
"int": 1,
"string": "abc",
"bool": True,
"dict": {"a": {"b": "c"}},
"list": [1, 2, 3],
"namespace": Namespace(foo=Namespace(bar="buzz")),
"layer": torch.nn.BatchNorm1d,
"tensor": torch.empty(2, 2, 2),
"array": np.empty([2, 2, 2]),
}
metrics = {"abc": torch.tensor([0.54])}
logger.log_hyperparams(hparams, metrics)
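

# Note (inferred from the test above, not from documented behavior): with
# ``default_hp_metric=False`` the logger skips its placeholder ``hp_metric``
# value, so initial metric values are passed explicitly with the hparams.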


@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
def test_tensorboard_log_graph(tmpdir, example_input_array):
    """Test that logging the graph works both via ``model.example_input_array`` and an externally passed array."""
# TODO(fabric): Test both nn.Module and LightningModule
# TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
model = BoringModel()
if example_input_array is not None:
model.example_input_array = None
logger = TensorBoardLogger(tmpdir, log_graph=True)
logger.log_graph(model, example_input_array)


@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
    """Test that ``log_graph`` warns when ``model.example_input_array`` is missing or cannot be traced."""
model = BoringModel()
model.example_input_array = None
logger = TensorBoardLogger(tmpdir, log_graph=True)
with pytest.warns(
UserWarning,
match="Could not log computational graph to TensorBoard: The `model.example_input_array` .* was not given",
):
logger.log_graph(model)

    model.example_input_array = dict(x=1, y=2)
with pytest.warns(
UserWarning, match="Could not log computational graph to TensorBoard: .* can't be traced by TensorBoard"
):
logger.log_graph(model)


def test_tensorboard_finalize(monkeypatch, tmpdir):
"""Test that the SummaryWriter closes in finalize."""
if _TENSORBOARD_AVAILABLE:
import torch.utils.tensorboard as tb
else:
import tensorboardX as tb
monkeypatch.setattr(tb, "SummaryWriter", Mock())
logger = TensorBoardLogger(root_dir=tmpdir)
assert logger._experiment is None
logger.finalize("any")
# no log calls, no experiment created -> nothing to flush
logger.experiment.assert_not_called()

    logger = TensorBoardLogger(root_dir=tmpdir)
logger.log_metrics({"flush_me": 11.1}) # trigger creation of an experiment
logger.finalize("any")
# finalize flushes to experiment directory
logger.experiment.flush.assert_called()
logger.experiment.close.assert_called()
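

# Usage note (an assumption drawn from the test above, not a documented API
# guarantee): ``finalize`` is a no-op when the experiment was never created,
# and otherwise flushes and closes the underlying SummaryWriter.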


@mock.patch("lightning_fabric.loggers.tensorboard.log")
def test_tensorboard_with_symlink(log, tmpdir):
    """Test a specific failure case when the tensorboard logger is used with an empty name, a symbolic-link
    ``root_dir``, and relative paths."""
os.chdir(tmpdir) # need to use relative paths
source = os.path.join(".", "lightning_logs")
dest = os.path.join(".", "sym_lightning_logs")
os.makedirs(source, exist_ok=True)
os.symlink(source, dest)
logger = TensorBoardLogger(root_dir=dest, name="")
_ = logger.version
log.warning.assert_not_called()


def test_tensorboard_missing_folder_warning(tmpdir, caplog):
    """Verify that the logger warns when the logger folder is missing."""
name = "fake_dir"
logger = TensorBoardLogger(root_dir=tmpdir, name=name)
with caplog.at_level(logging.WARNING):
assert logger.version == 0
assert "Missing logger folder:" in caplog.text