Lazy import tensorboard (#15762)

Carlos Mocholí, 2022-11-23 13:59:38 +01:00, committed by GitHub
parent 952b64b358
commit a9708105f7
4 changed files with 59 additions and 19 deletions

View File

@@ -19,12 +19,10 @@ TensorBoard Logger
 import logging
 import os
 from argparse import Namespace
-from typing import Any, Dict, Mapping, Optional, Union
+from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union

 import numpy as np
 from lightning_utilities.core.imports import RequirementCache
-from tensorboardX import SummaryWriter
-from tensorboardX.summary import hparams
 from torch import Tensor

 import pytorch_lightning as pl
@@ -40,6 +38,13 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
 log = logging.getLogger(__name__)

 _TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
+_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX")
+if TYPE_CHECKING:
+    # assumes at least one will be installed when type checking
+    if _TENSORBOARD_AVAILABLE:
+        from torch.utils.tensorboard import SummaryWriter
+    else:
+        from tensorboardX import SummaryWriter  # type: ignore[no-redef]

 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container, OmegaConf
@@ -109,6 +114,10 @@ class TensorBoardLogger(Logger):
         sub_dir: Optional[_PATH] = None,
         **kwargs: Any,
     ):
+        if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
+            raise ModuleNotFoundError(
+                "Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either."
+            )
         super().__init__()
         save_dir = os.fspath(save_dir)
         self._save_dir = save_dir
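
With neither backend installed, construction now fails fast instead of `pytorch_lightning` failing at import time. A hypothetical session:

>>> from pytorch_lightning.loggers import TensorBoardLogger
>>> TensorBoardLogger(save_dir="logs/")
ModuleNotFoundError: Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either.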
@@ -172,7 +181,7 @@ class TensorBoardLogger(Logger):

     @property
     @rank_zero_experiment
-    def experiment(self) -> SummaryWriter:
+    def experiment(self) -> "SummaryWriter":
         r"""
         Actual tensorboard object. To use TensorBoard features in your
         :class:`~pytorch_lightning.core.module.LightningModule` do the following.
@@ -188,6 +197,12 @@ class TensorBoardLogger(Logger):
         assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0"
         if self.root_dir:
             self._fs.makedirs(self.root_dir, exist_ok=True)
+
+        if _TENSORBOARD_AVAILABLE:
+            from torch.utils.tensorboard import SummaryWriter
+        else:
+            from tensorboardX import SummaryWriter  # type: ignore[no-redef]
+
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
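
The import now happens on first access of `experiment`, and the result is cached on `self._experiment`. A self-contained sketch of the same lazy-property shape (class and names are illustrative; the diff gates on the `RequirementCache` flags, while a try/except fallback gives an equivalent shape for a standalone example):

class LazyWriterHolder:
    def __init__(self, log_dir: str) -> None:
        self._log_dir = log_dir
        self._experiment = None

    @property
    def experiment(self):
        # Create the writer once, on first access, and cache it.
        if self._experiment is not None:
            return self._experiment
        # Prefer the torch-bundled backend; fall back to tensorboardX.
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ModuleNotFoundError:
            from tensorboardX import SummaryWriter
        self._experiment = SummaryWriter(log_dir=self._log_dir)
        return self._experiment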
@@ -224,6 +239,12 @@ class TensorBoardLogger(Logger):
             if metrics:
                 self.log_metrics(metrics, 0)
+
+            if _TENSORBOARD_AVAILABLE:
+                from torch.utils.tensorboard.summary import hparams
+            else:
+                from tensorboardX.summary import hparams  # type: ignore[no-redef]
+
             exp, ssi, sei = hparams(params, metrics)
             writer = self.experiment._get_file_writer()
             writer.add_summary(exp)
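
The `hparams` helper builds the three summary protos that TensorBoard's HPARAMS plugin expects. A sketch of what the surrounding method does with them, assuming `tensorboard` is installed (`_get_file_writer` is a private API, used here only because the hunk above does):

from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams

writer = SummaryWriter(log_dir="logs/")
# exp/ssi/sei: experiment, session start, and session end summaries.
exp, ssi, sei = hparams({"lr": 0.01}, {"hp_metric": 0.0})
file_writer = writer._get_file_writer()
file_writer.add_summary(exp)
file_writer.add_summary(ssi)
file_writer.add_summary(sei)
writer.close()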

View File

@@ -15,7 +15,7 @@ import contextlib
 import inspect
 import pickle
 from unittest import mock
-from unittest.mock import ANY
+from unittest.mock import ANY, Mock

 import pytest
 import torch
@@ -31,6 +31,7 @@ from pytorch_lightning.loggers import (
     WandbLogger,
 )
 from pytorch_lightning.loggers.logger import DummyExperiment
+from pytorch_lightning.loggers.tensorboard import _TENSORBOARD_AVAILABLE
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.loggers.test_comet import _patch_comet_atexit
 from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation
@@ -300,10 +301,15 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
         logger.experiment.__getitem__().log.assert_called_once_with(1.0)

     # TensorBoard
-    with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
-        logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
-        logger.log_metrics({"test": 1.0}, step=0)
-        logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
+    if _TENSORBOARD_AVAILABLE:
+        import torch.utils.tensorboard as tb
+    else:
+        import tensorboardX as tb
+    monkeypatch.setattr(tb, "SummaryWriter", Mock())
+    logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
+    logger.log_metrics({"test": 1.0}, step=0)
+    logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)

     # WandB
     with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
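
Since `SummaryWriter` is no longer a module-level attribute of `pytorch_lightning.loggers.tensorboard`, `mock.patch` against that path would raise; the tests instead patch the attribute in whichever backend module provides it. The pattern in isolation, assuming `tensorboard` is installed (the test name is illustrative):

from unittest.mock import Mock


def test_patched_writer(monkeypatch):
    import torch.utils.tensorboard as tb

    # Replace the class in the module the logger imports from.
    monkeypatch.setattr(tb, "SummaryWriter", Mock())
    writer = tb.SummaryWriter(log_dir="logs/")  # returns a Mock instance
    writer.add_scalar("loss", 0.5, 0)
    writer.add_scalar.assert_called_once_with("loss", 0.5, 0)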
@@ -316,7 +322,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
         logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})


-def test_logger_default_name(tmpdir):
+def test_logger_default_name(tmpdir, monkeypatch):
     """Test that the default logger name is lightning_logs."""

     # CSV
@@ -324,9 +330,14 @@ def test_logger_default_name(tmpdir):
     assert logger.name == "lightning_logs"

     # TensorBoard
-    with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
-        logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
-        assert logger.name == "lightning_logs"
+    if _TENSORBOARD_AVAILABLE:
+        import torch.utils.tensorboard as tb
+    else:
+        import tensorboardX as tb
+    monkeypatch.setattr(tb, "SummaryWriter", Mock())
+    logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
+    assert logger.name == "lightning_logs"

     # MLflow
     with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(

View File

@@ -15,6 +15,7 @@ import logging
 import os
 from argparse import Namespace
 from unittest import mock
+from unittest.mock import Mock

 import numpy as np
 import pytest
@@ -278,23 +279,28 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
     assert count_steps == model.indexes


-@mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter")
-def test_tensorboard_finalize(summary_writer, tmpdir):
+def test_tensorboard_finalize(monkeypatch, tmpdir):
     """Test that the SummaryWriter closes in finalize."""
+    if _TENSORBOARD_AVAILABLE:
+        import torch.utils.tensorboard as tb
+    else:
+        import tensorboardX as tb
+    monkeypatch.setattr(tb, "SummaryWriter", Mock())
     logger = TensorBoardLogger(save_dir=tmpdir)
     assert logger._experiment is None
     logger.finalize("any")

     # no log calls, no experiment created -> nothing to flush
-    summary_writer.assert_not_called()
+    logger.experiment.assert_not_called()

     logger = TensorBoardLogger(save_dir=tmpdir)
     logger.log_metrics({"flush_me": 11.1})  # trigger creation of an experiment
     logger.finalize("any")

     # finalize flushes to experiment directory
-    summary_writer().flush.assert_called()
-    summary_writer().close.assert_called()
+    logger.experiment.flush.assert_called()
+    logger.experiment.close.assert_called()


 def test_tensorboard_save_hparams_to_yaml_once(tmpdir):
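
The assertions can move from the patched class to `logger.experiment` because a `Mock`'s `return_value` is a single shared object: the writer instance the logger constructed is the same object every later access yields. A minimal illustration:

from unittest.mock import Mock

SummaryWriter = Mock()
experiment = SummaryWriter(log_dir="logs/")  # always the same return_value
experiment.flush()
experiment.flush.assert_called()
SummaryWriter.assert_called_once()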

View File

@@ -1330,7 +1330,9 @@ def test_tensorboard_logger_init_args():
         "TensorBoardLogger",
         {
             "save_dir": "tb",  # Resolve from TensorBoardLogger.__init__
-            "comment": "tb",  # Resolve from tensorboard.writer.SummaryWriter.__init__
+        },
+        {
+            "comment": "tb",  # Unsupported resolving from local imports
         },
     )
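
This last hunk updates a LightningCLI test's expectations: `save_dir` is still resolvable by introspecting `TensorBoardLogger.__init__`, but `comment`, which was previously discovered by following `**kwargs` to the module-level `SummaryWriter`, cannot be resolved once the import is local to a method. The parameter still exists on the writer itself, as this sketch (assuming `tensorboard` is installed) shows; it just is no longer reachable through static inspection of the logger module:

import inspect

from torch.utils.tensorboard import SummaryWriter

# The writer exposes `comment`, but jsonargparse-style introspection of
# pytorch_lightning.loggers.tensorboard can no longer find the class to
# follow **kwargs into.
print("comment" in inspect.signature(SummaryWriter.__init__).parameters)  # True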