Lazy import tensorboard (#15762)
This commit is contained in:
parent
952b64b358
commit
a9708105f7
|
@ -19,12 +19,10 @@ TensorBoard Logger
|
|||
import logging
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, Mapping, Optional, Union
|
||||
from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
from tensorboardX import SummaryWriter
|
||||
from tensorboardX.summary import hparams
|
||||
from torch import Tensor
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
@ -40,6 +38,13 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
|
||||
_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX")
|
||||
if TYPE_CHECKING:
|
||||
# assumes at least one will be installed when type checking
|
||||
if _TENSORBOARD_AVAILABLE:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
else:
|
||||
from tensorboardX import SummaryWriter # type: ignore[no-redef]
|
||||
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
from omegaconf import Container, OmegaConf
|
||||
|
@ -109,6 +114,10 @@ class TensorBoardLogger(Logger):
|
|||
sub_dir: Optional[_PATH] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
|
||||
raise ModuleNotFoundError(
|
||||
"Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either."
|
||||
)
|
||||
super().__init__()
|
||||
save_dir = os.fspath(save_dir)
|
||||
self._save_dir = save_dir
|
||||
|
@ -172,7 +181,7 @@ class TensorBoardLogger(Logger):
|
|||
|
||||
@property
|
||||
@rank_zero_experiment
|
||||
def experiment(self) -> SummaryWriter:
|
||||
def experiment(self) -> "SummaryWriter":
|
||||
r"""
|
||||
Actual tensorboard object. To use TensorBoard features in your
|
||||
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
|
||||
|
@ -188,6 +197,12 @@ class TensorBoardLogger(Logger):
|
|||
assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0"
|
||||
if self.root_dir:
|
||||
self._fs.makedirs(self.root_dir, exist_ok=True)
|
||||
|
||||
if _TENSORBOARD_AVAILABLE:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
else:
|
||||
from tensorboardX import SummaryWriter # type: ignore[no-redef]
|
||||
|
||||
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
|
||||
return self._experiment
|
||||
|
||||
|
@ -224,6 +239,12 @@ class TensorBoardLogger(Logger):
|
|||
|
||||
if metrics:
|
||||
self.log_metrics(metrics, 0)
|
||||
|
||||
if _TENSORBOARD_AVAILABLE:
|
||||
from torch.utils.tensorboard.summary import hparams
|
||||
else:
|
||||
from tensorboardX.summary import hparams # type: ignore[no-redef]
|
||||
|
||||
exp, ssi, sei = hparams(params, metrics)
|
||||
writer = self.experiment._get_file_writer()
|
||||
writer.add_summary(exp)
|
||||
|
|
|
@ -15,7 +15,7 @@ import contextlib
|
|||
import inspect
|
||||
import pickle
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY
|
||||
from unittest.mock import ANY, Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -31,6 +31,7 @@ from pytorch_lightning.loggers import (
|
|||
WandbLogger,
|
||||
)
|
||||
from pytorch_lightning.loggers.logger import DummyExperiment
|
||||
from pytorch_lightning.loggers.tensorboard import _TENSORBOARD_AVAILABLE
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
from tests_pytorch.loggers.test_comet import _patch_comet_atexit
|
||||
from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation
|
||||
|
@ -300,10 +301,15 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
|
|||
logger.experiment.__getitem__().log.assert_called_once_with(1.0)
|
||||
|
||||
# TensorBoard
|
||||
with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
|
||||
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
|
||||
logger.log_metrics({"test": 1.0}, step=0)
|
||||
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
|
||||
if _TENSORBOARD_AVAILABLE:
|
||||
import torch.utils.tensorboard as tb
|
||||
else:
|
||||
import tensorboardX as tb
|
||||
|
||||
monkeypatch.setattr(tb, "SummaryWriter", Mock())
|
||||
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
|
||||
logger.log_metrics({"test": 1.0}, step=0)
|
||||
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
|
||||
|
||||
# WandB
|
||||
with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
|
||||
|
@ -316,7 +322,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
|
|||
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
|
||||
|
||||
|
||||
def test_logger_default_name(tmpdir):
|
||||
def test_logger_default_name(tmpdir, monkeypatch):
|
||||
"""Test that the default logger name is lightning_logs."""
|
||||
|
||||
# CSV
|
||||
|
@ -324,9 +330,14 @@ def test_logger_default_name(tmpdir):
|
|||
assert logger.name == "lightning_logs"
|
||||
|
||||
# TensorBoard
|
||||
with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
|
||||
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
|
||||
assert logger.name == "lightning_logs"
|
||||
if _TENSORBOARD_AVAILABLE:
|
||||
import torch.utils.tensorboard as tb
|
||||
else:
|
||||
import tensorboardX as tb
|
||||
|
||||
monkeypatch.setattr(tb, "SummaryWriter", Mock())
|
||||
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
|
||||
assert logger.name == "lightning_logs"
|
||||
|
||||
# MLflow
|
||||
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
|
||||
|
|
|
@ -15,6 +15,7 @@ import logging
|
|||
import os
|
||||
from argparse import Namespace
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -278,23 +279,28 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
|
|||
assert count_steps == model.indexes
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter")
|
||||
def test_tensorboard_finalize(summary_writer, tmpdir):
|
||||
def test_tensorboard_finalize(monkeypatch, tmpdir):
|
||||
"""Test that the SummaryWriter closes in finalize."""
|
||||
if _TENSORBOARD_AVAILABLE:
|
||||
import torch.utils.tensorboard as tb
|
||||
else:
|
||||
import tensorboardX as tb
|
||||
|
||||
monkeypatch.setattr(tb, "SummaryWriter", Mock())
|
||||
logger = TensorBoardLogger(save_dir=tmpdir)
|
||||
assert logger._experiment is None
|
||||
logger.finalize("any")
|
||||
|
||||
# no log calls, no experiment created -> nothing to flush
|
||||
summary_writer.assert_not_called()
|
||||
logger.experiment.assert_not_called()
|
||||
|
||||
logger = TensorBoardLogger(save_dir=tmpdir)
|
||||
logger.log_metrics({"flush_me": 11.1}) # trigger creation of an experiment
|
||||
logger.finalize("any")
|
||||
|
||||
# finalize flushes to experiment directory
|
||||
summary_writer().flush.assert_called()
|
||||
summary_writer().close.assert_called()
|
||||
logger.experiment.flush.assert_called()
|
||||
logger.experiment.close.assert_called()
|
||||
|
||||
|
||||
def test_tensorboard_save_hparams_to_yaml_once(tmpdir):
|
||||
|
|
|
@ -1330,7 +1330,9 @@ def test_tensorboard_logger_init_args():
|
|||
"TensorBoardLogger",
|
||||
{
|
||||
"save_dir": "tb", # Resolve from TensorBoardLogger.__init__
|
||||
"comment": "tb", # Resolve from tensorboard.writer.SummaryWriter.__init__
|
||||
},
|
||||
{
|
||||
"comment": "tb", # Unsupported resolving from local imports
|
||||
},
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue