From a9708105f70e16f0520dbd9895616b927edb81aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 23 Nov 2022 13:59:38 +0100 Subject: [PATCH] Lazy import tensorboard (#15762) --- src/pytorch_lightning/loggers/tensorboard.py | 29 ++++++++++++++++--- tests/tests_pytorch/loggers/test_all.py | 29 +++++++++++++------ .../tests_pytorch/loggers/test_tensorboard.py | 16 ++++++---- tests/tests_pytorch/test_cli.py | 4 ++- 4 files changed, 59 insertions(+), 19 deletions(-) diff --git a/src/pytorch_lightning/loggers/tensorboard.py b/src/pytorch_lightning/loggers/tensorboard.py index 1c840a3dea..25b0d8fa0c 100644 --- a/src/pytorch_lightning/loggers/tensorboard.py +++ b/src/pytorch_lightning/loggers/tensorboard.py @@ -19,12 +19,10 @@ TensorBoard Logger import logging import os from argparse import Namespace -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union import numpy as np from lightning_utilities.core.imports import RequirementCache -from tensorboardX import SummaryWriter -from tensorboardX.summary import hparams from torch import Tensor import pytorch_lightning as pl @@ -40,6 +38,13 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn log = logging.getLogger(__name__) _TENSORBOARD_AVAILABLE = RequirementCache("tensorboard") +_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX") +if TYPE_CHECKING: + # assumes at least one will be installed when type checking + if _TENSORBOARD_AVAILABLE: + from torch.utils.tensorboard import SummaryWriter + else: + from tensorboardX import SummaryWriter # type: ignore[no-redef] if _OMEGACONF_AVAILABLE: from omegaconf import Container, OmegaConf @@ -109,6 +114,10 @@ class TensorBoardLogger(Logger): sub_dir: Optional[_PATH] = None, **kwargs: Any, ): + if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE: + raise ModuleNotFoundError( + "Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either." + ) super().__init__() save_dir = os.fspath(save_dir) self._save_dir = save_dir @@ -172,7 +181,7 @@ class TensorBoardLogger(Logger): @property @rank_zero_experiment - def experiment(self) -> SummaryWriter: + def experiment(self) -> "SummaryWriter": r""" Actual tensorboard object. To use TensorBoard features in your :class:`~pytorch_lightning.core.module.LightningModule` do the following. @@ -188,6 +197,12 @@ class TensorBoardLogger(Logger): assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0" if self.root_dir: self._fs.makedirs(self.root_dir, exist_ok=True) + + if _TENSORBOARD_AVAILABLE: + from torch.utils.tensorboard import SummaryWriter + else: + from tensorboardX import SummaryWriter # type: ignore[no-redef] + self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs) return self._experiment @@ -224,6 +239,12 @@ class TensorBoardLogger(Logger): if metrics: self.log_metrics(metrics, 0) + + if _TENSORBOARD_AVAILABLE: + from torch.utils.tensorboard.summary import hparams + else: + from tensorboardX.summary import hparams # type: ignore[no-redef] + exp, ssi, sei = hparams(params, metrics) writer = self.experiment._get_file_writer() writer.add_summary(exp) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 4244e98455..4477b13b5b 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -15,7 +15,7 @@ import contextlib import inspect import pickle from unittest import mock -from unittest.mock import ANY +from unittest.mock import ANY, Mock import pytest import torch @@ -31,6 +31,7 @@ from pytorch_lightning.loggers import ( WandbLogger, ) from pytorch_lightning.loggers.logger import DummyExperiment +from pytorch_lightning.loggers.tensorboard import _TENSORBOARD_AVAILABLE from tests_pytorch.helpers.runif import RunIf from tests_pytorch.loggers.test_comet import _patch_comet_atexit from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation @@ -300,10 +301,15 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): logger.experiment.__getitem__().log.assert_called_once_with(1.0) # TensorBoard - with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"): - logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix) - logger.log_metrics({"test": 1.0}, step=0) - logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0) + if _TENSORBOARD_AVAILABLE: + import torch.utils.tensorboard as tb + else: + import tensorboardX as tb + + monkeypatch.setattr(tb, "SummaryWriter", Mock()) + logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix) + logger.log_metrics({"test": 1.0}, step=0) + logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0) # WandB with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch( @@ -316,7 +322,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0}) -def test_logger_default_name(tmpdir): +def test_logger_default_name(tmpdir, monkeypatch): """Test that the default logger name is lightning_logs.""" # CSV @@ -324,9 +330,14 @@ def test_logger_default_name(tmpdir): assert logger.name == "lightning_logs" # TensorBoard - with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"): - logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir) - assert logger.name == "lightning_logs" + if _TENSORBOARD_AVAILABLE: + import torch.utils.tensorboard as tb + else: + import tensorboardX as tb + + monkeypatch.setattr(tb, "SummaryWriter", Mock()) + logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir) + assert logger.name == "lightning_logs" # MLflow with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch( diff --git a/tests/tests_pytorch/loggers/test_tensorboard.py b/tests/tests_pytorch/loggers/test_tensorboard.py index ddab738269..7189f4c735 100644 --- a/tests/tests_pytorch/loggers/test_tensorboard.py +++ b/tests/tests_pytorch/loggers/test_tensorboard.py @@ -15,6 +15,7 @@ import logging import os from argparse import Namespace from unittest import mock +from unittest.mock import Mock import numpy as np import pytest @@ -278,23 +279,28 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir): assert count_steps == model.indexes -@mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter") -def test_tensorboard_finalize(summary_writer, tmpdir): +def test_tensorboard_finalize(monkeypatch, tmpdir): """Test that the SummaryWriter closes in finalize.""" + if _TENSORBOARD_AVAILABLE: + import torch.utils.tensorboard as tb + else: + import tensorboardX as tb + + monkeypatch.setattr(tb, "SummaryWriter", Mock()) logger = TensorBoardLogger(save_dir=tmpdir) assert logger._experiment is None logger.finalize("any") # no log calls, no experiment created -> nothing to flush - summary_writer.assert_not_called() + logger.experiment.assert_not_called() logger = TensorBoardLogger(save_dir=tmpdir) logger.log_metrics({"flush_me": 11.1}) # trigger creation of an experiment logger.finalize("any") # finalize flushes to experiment directory - summary_writer().flush.assert_called() - summary_writer().close.assert_called() + logger.experiment.flush.assert_called() + logger.experiment.close.assert_called() def test_tensorboard_save_hparams_to_yaml_once(tmpdir): diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 5e864cea35..79562e52e3 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1330,7 +1330,9 @@ def test_tensorboard_logger_init_args(): "TensorBoardLogger", { "save_dir": "tb", # Resolve from TensorBoardLogger.__init__ - "comment": "tb", # Resolve from tensorboard.writer.SummaryWriter.__init__ + }, + { + "comment": "tb", # Unsupported resolving from local imports }, )