Lazy import tensorboard (#15762)
This commit is contained in:
parent
952b64b358
commit
a9708105f7
|
@ -19,12 +19,10 @@ TensorBoard Logger
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from typing import Any, Dict, Mapping, Optional, Union
|
from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from lightning_utilities.core.imports import RequirementCache
|
from lightning_utilities.core.imports import RequirementCache
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
from tensorboardX.summary import hparams
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
@ -40,6 +38,13 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
|
_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
|
||||||
|
_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX")
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# assumes at least one will be installed when type checking
|
||||||
|
if _TENSORBOARD_AVAILABLE:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
else:
|
||||||
|
from tensorboardX import SummaryWriter # type: ignore[no-redef]
|
||||||
|
|
||||||
if _OMEGACONF_AVAILABLE:
|
if _OMEGACONF_AVAILABLE:
|
||||||
from omegaconf import Container, OmegaConf
|
from omegaconf import Container, OmegaConf
|
||||||
|
@ -109,6 +114,10 @@ class TensorBoardLogger(Logger):
|
||||||
sub_dir: Optional[_PATH] = None,
|
sub_dir: Optional[_PATH] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
|
if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either."
|
||||||
|
)
|
||||||
super().__init__()
|
super().__init__()
|
||||||
save_dir = os.fspath(save_dir)
|
save_dir = os.fspath(save_dir)
|
||||||
self._save_dir = save_dir
|
self._save_dir = save_dir
|
||||||
|
@ -172,7 +181,7 @@ class TensorBoardLogger(Logger):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@rank_zero_experiment
|
@rank_zero_experiment
|
||||||
def experiment(self) -> SummaryWriter:
|
def experiment(self) -> "SummaryWriter":
|
||||||
r"""
|
r"""
|
||||||
Actual tensorboard object. To use TensorBoard features in your
|
Actual tensorboard object. To use TensorBoard features in your
|
||||||
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
|
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
|
||||||
|
@ -188,6 +197,12 @@ class TensorBoardLogger(Logger):
|
||||||
assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0"
|
assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0"
|
||||||
if self.root_dir:
|
if self.root_dir:
|
||||||
self._fs.makedirs(self.root_dir, exist_ok=True)
|
self._fs.makedirs(self.root_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if _TENSORBOARD_AVAILABLE:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
else:
|
||||||
|
from tensorboardX import SummaryWriter # type: ignore[no-redef]
|
||||||
|
|
||||||
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
|
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
|
||||||
return self._experiment
|
return self._experiment
|
||||||
|
|
||||||
|
@ -224,6 +239,12 @@ class TensorBoardLogger(Logger):
|
||||||
|
|
||||||
if metrics:
|
if metrics:
|
||||||
self.log_metrics(metrics, 0)
|
self.log_metrics(metrics, 0)
|
||||||
|
|
||||||
|
if _TENSORBOARD_AVAILABLE:
|
||||||
|
from torch.utils.tensorboard.summary import hparams
|
||||||
|
else:
|
||||||
|
from tensorboardX.summary import hparams # type: ignore[no-redef]
|
||||||
|
|
||||||
exp, ssi, sei = hparams(params, metrics)
|
exp, ssi, sei = hparams(params, metrics)
|
||||||
writer = self.experiment._get_file_writer()
|
writer = self.experiment._get_file_writer()
|
||||||
writer.add_summary(exp)
|
writer.add_summary(exp)
|
||||||
|
|
|
@ -15,7 +15,7 @@ import contextlib
|
||||||
import inspect
|
import inspect
|
||||||
import pickle
|
import pickle
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import ANY
|
from unittest.mock import ANY, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -31,6 +31,7 @@ from pytorch_lightning.loggers import (
|
||||||
WandbLogger,
|
WandbLogger,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.loggers.logger import DummyExperiment
|
from pytorch_lightning.loggers.logger import DummyExperiment
|
||||||
|
from pytorch_lightning.loggers.tensorboard import _TENSORBOARD_AVAILABLE
|
||||||
from tests_pytorch.helpers.runif import RunIf
|
from tests_pytorch.helpers.runif import RunIf
|
||||||
from tests_pytorch.loggers.test_comet import _patch_comet_atexit
|
from tests_pytorch.loggers.test_comet import _patch_comet_atexit
|
||||||
from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation
|
from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation
|
||||||
|
@ -300,10 +301,15 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
|
||||||
logger.experiment.__getitem__().log.assert_called_once_with(1.0)
|
logger.experiment.__getitem__().log.assert_called_once_with(1.0)
|
||||||
|
|
||||||
# TensorBoard
|
# TensorBoard
|
||||||
with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
|
if _TENSORBOARD_AVAILABLE:
|
||||||
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
|
import torch.utils.tensorboard as tb
|
||||||
logger.log_metrics({"test": 1.0}, step=0)
|
else:
|
||||||
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
|
import tensorboardX as tb
|
||||||
|
|
||||||
|
monkeypatch.setattr(tb, "SummaryWriter", Mock())
|
||||||
|
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
|
||||||
|
logger.log_metrics({"test": 1.0}, step=0)
|
||||||
|
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
|
||||||
|
|
||||||
# WandB
|
# WandB
|
||||||
with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
|
with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
|
||||||
|
@ -316,7 +322,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
|
||||||
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
|
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
|
||||||
|
|
||||||
|
|
||||||
def test_logger_default_name(tmpdir):
|
def test_logger_default_name(tmpdir, monkeypatch):
|
||||||
"""Test that the default logger name is lightning_logs."""
|
"""Test that the default logger name is lightning_logs."""
|
||||||
|
|
||||||
# CSV
|
# CSV
|
||||||
|
@ -324,9 +330,14 @@ def test_logger_default_name(tmpdir):
|
||||||
assert logger.name == "lightning_logs"
|
assert logger.name == "lightning_logs"
|
||||||
|
|
||||||
# TensorBoard
|
# TensorBoard
|
||||||
with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
|
if _TENSORBOARD_AVAILABLE:
|
||||||
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
|
import torch.utils.tensorboard as tb
|
||||||
assert logger.name == "lightning_logs"
|
else:
|
||||||
|
import tensorboardX as tb
|
||||||
|
|
||||||
|
monkeypatch.setattr(tb, "SummaryWriter", Mock())
|
||||||
|
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
|
||||||
|
assert logger.name == "lightning_logs"
|
||||||
|
|
||||||
# MLflow
|
# MLflow
|
||||||
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
|
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
|
||||||
|
|
|
@ -15,6 +15,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -278,23 +279,28 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
|
||||||
assert count_steps == model.indexes
|
assert count_steps == model.indexes
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter")
|
def test_tensorboard_finalize(monkeypatch, tmpdir):
|
||||||
def test_tensorboard_finalize(summary_writer, tmpdir):
|
|
||||||
"""Test that the SummaryWriter closes in finalize."""
|
"""Test that the SummaryWriter closes in finalize."""
|
||||||
|
if _TENSORBOARD_AVAILABLE:
|
||||||
|
import torch.utils.tensorboard as tb
|
||||||
|
else:
|
||||||
|
import tensorboardX as tb
|
||||||
|
|
||||||
|
monkeypatch.setattr(tb, "SummaryWriter", Mock())
|
||||||
logger = TensorBoardLogger(save_dir=tmpdir)
|
logger = TensorBoardLogger(save_dir=tmpdir)
|
||||||
assert logger._experiment is None
|
assert logger._experiment is None
|
||||||
logger.finalize("any")
|
logger.finalize("any")
|
||||||
|
|
||||||
# no log calls, no experiment created -> nothing to flush
|
# no log calls, no experiment created -> nothing to flush
|
||||||
summary_writer.assert_not_called()
|
logger.experiment.assert_not_called()
|
||||||
|
|
||||||
logger = TensorBoardLogger(save_dir=tmpdir)
|
logger = TensorBoardLogger(save_dir=tmpdir)
|
||||||
logger.log_metrics({"flush_me": 11.1}) # trigger creation of an experiment
|
logger.log_metrics({"flush_me": 11.1}) # trigger creation of an experiment
|
||||||
logger.finalize("any")
|
logger.finalize("any")
|
||||||
|
|
||||||
# finalize flushes to experiment directory
|
# finalize flushes to experiment directory
|
||||||
summary_writer().flush.assert_called()
|
logger.experiment.flush.assert_called()
|
||||||
summary_writer().close.assert_called()
|
logger.experiment.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
def test_tensorboard_save_hparams_to_yaml_once(tmpdir):
|
def test_tensorboard_save_hparams_to_yaml_once(tmpdir):
|
||||||
|
|
|
@ -1330,7 +1330,9 @@ def test_tensorboard_logger_init_args():
|
||||||
"TensorBoardLogger",
|
"TensorBoardLogger",
|
||||||
{
|
{
|
||||||
"save_dir": "tb", # Resolve from TensorBoardLogger.__init__
|
"save_dir": "tb", # Resolve from TensorBoardLogger.__init__
|
||||||
"comment": "tb", # Resolve from tensorboard.writer.SummaryWriter.__init__
|
},
|
||||||
|
{
|
||||||
|
"comment": "tb", # Unsupported resolving from local imports
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue