Lazy import tensorboard (#15762)

This commit is contained in:
Carlos Mocholí 2022-11-23 13:59:38 +01:00 committed by GitHub
parent 952b64b358
commit a9708105f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 19 deletions

View File

@ -19,12 +19,10 @@ TensorBoard Logger
import logging
import os
from argparse import Namespace
from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union
import numpy as np
from lightning_utilities.core.imports import RequirementCache
from tensorboardX import SummaryWriter
from tensorboardX.summary import hparams
from torch import Tensor
import pytorch_lightning as pl
@ -40,6 +38,13 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)
_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX")
if TYPE_CHECKING:
# assumes at least one will be installed when type checking
if _TENSORBOARD_AVAILABLE:
from torch.utils.tensorboard import SummaryWriter
else:
from tensorboardX import SummaryWriter # type: ignore[no-redef]
if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf
@ -109,6 +114,10 @@ class TensorBoardLogger(Logger):
sub_dir: Optional[_PATH] = None,
**kwargs: Any,
):
if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
raise ModuleNotFoundError(
"Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either."
)
super().__init__()
save_dir = os.fspath(save_dir)
self._save_dir = save_dir
@ -172,7 +181,7 @@ class TensorBoardLogger(Logger):
@property
@rank_zero_experiment
def experiment(self) -> SummaryWriter:
def experiment(self) -> "SummaryWriter":
r"""
Actual tensorboard object. To use TensorBoard features in your
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
@ -188,6 +197,12 @@ class TensorBoardLogger(Logger):
assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0"
if self.root_dir:
self._fs.makedirs(self.root_dir, exist_ok=True)
if _TENSORBOARD_AVAILABLE:
from torch.utils.tensorboard import SummaryWriter
else:
from tensorboardX import SummaryWriter # type: ignore[no-redef]
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment
@ -224,6 +239,12 @@ class TensorBoardLogger(Logger):
if metrics:
self.log_metrics(metrics, 0)
if _TENSORBOARD_AVAILABLE:
from torch.utils.tensorboard.summary import hparams
else:
from tensorboardX.summary import hparams # type: ignore[no-redef]
exp, ssi, sei = hparams(params, metrics)
writer = self.experiment._get_file_writer()
writer.add_summary(exp)

View File

@ -15,7 +15,7 @@ import contextlib
import inspect
import pickle
from unittest import mock
from unittest.mock import ANY
from unittest.mock import ANY, Mock
import pytest
import torch
@ -31,6 +31,7 @@ from pytorch_lightning.loggers import (
WandbLogger,
)
from pytorch_lightning.loggers.logger import DummyExperiment
from pytorch_lightning.loggers.tensorboard import _TENSORBOARD_AVAILABLE
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.loggers.test_comet import _patch_comet_atexit
from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation
@ -300,10 +301,15 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
logger.experiment.__getitem__().log.assert_called_once_with(1.0)
# TensorBoard
with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
if _TENSORBOARD_AVAILABLE:
import torch.utils.tensorboard as tb
else:
import tensorboardX as tb
monkeypatch.setattr(tb, "SummaryWriter", Mock())
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
# WandB
with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
@ -316,7 +322,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
def test_logger_default_name(tmpdir):
def test_logger_default_name(tmpdir, monkeypatch):
"""Test that the default logger name is lightning_logs."""
# CSV
@ -324,9 +330,14 @@ def test_logger_default_name(tmpdir):
assert logger.name == "lightning_logs"
# TensorBoard
with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
assert logger.name == "lightning_logs"
if _TENSORBOARD_AVAILABLE:
import torch.utils.tensorboard as tb
else:
import tensorboardX as tb
monkeypatch.setattr(tb, "SummaryWriter", Mock())
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
assert logger.name == "lightning_logs"
# MLflow
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(

View File

@ -15,6 +15,7 @@ import logging
import os
from argparse import Namespace
from unittest import mock
from unittest.mock import Mock
import numpy as np
import pytest
@ -278,23 +279,28 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
assert count_steps == model.indexes
@mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter")
def test_tensorboard_finalize(summary_writer, tmpdir):
def test_tensorboard_finalize(monkeypatch, tmpdir):
"""Test that the SummaryWriter closes in finalize."""
if _TENSORBOARD_AVAILABLE:
import torch.utils.tensorboard as tb
else:
import tensorboardX as tb
monkeypatch.setattr(tb, "SummaryWriter", Mock())
logger = TensorBoardLogger(save_dir=tmpdir)
assert logger._experiment is None
logger.finalize("any")
# no log calls, no experiment created -> nothing to flush
summary_writer.assert_not_called()
logger.experiment.assert_not_called()
logger = TensorBoardLogger(save_dir=tmpdir)
logger.log_metrics({"flush_me": 11.1}) # trigger creation of an experiment
logger.finalize("any")
# finalize flushes to experiment directory
summary_writer().flush.assert_called()
summary_writer().close.assert_called()
logger.experiment.flush.assert_called()
logger.experiment.close.assert_called()
def test_tensorboard_save_hparams_to_yaml_once(tmpdir):

View File

@ -1330,7 +1330,9 @@ def test_tensorboard_logger_init_args():
"TensorBoardLogger",
{
"save_dir": "tb", # Resolve from TensorBoardLogger.__init__
"comment": "tb", # Resolve from tensorboard.writer.SummaryWriter.__init__
},
{
"comment": "tb", # Unsupported resolving from local imports
},
)