diff --git a/pytorch_lightning/logging/__init__.py b/pytorch_lightning/logging/__init__.py
index 1d2bce7281..8a6c3d03e1 100644
--- a/pytorch_lightning/logging/__init__.py
+++ b/pytorch_lightning/logging/__init__.py
@@ -168,6 +168,8 @@ Every k batches, lightning will write the new logs to disk
 from os import environ
 from .base import LightningLoggerBase, rank_zero_only
 
+from .tensorboard import TensorboardLogger
+
 try:
     from .test_tube import TestTubeLogger
 except ImportError:
diff --git a/pytorch_lightning/logging/tensorboard.py b/pytorch_lightning/logging/tensorboard.py
new file mode 100644
index 0000000000..338b02cc7a
--- /dev/null
+++ b/pytorch_lightning/logging/tensorboard.py
@@ -0,0 +1,122 @@
+import os
+from warnings import warn
+
+import torch
+from pkg_resources import parse_version
+from torch.utils.tensorboard import SummaryWriter
+
+from .base import LightningLoggerBase, rank_zero_only
+
+
+class TensorboardLogger(LightningLoggerBase):
+    r"""Log to local file system in Tensorboard format
+
+    Implemented using :class:`torch.utils.tensorboard.SummaryWriter`. Logs are saved to
+    `os.path.join(save_dir, name, version)`
+
+    :example:
+
+    .. code-block:: python
+
+        logger = TensorboardLogger("tb_logs", name="my_model")
+        trainer = Trainer(logger=logger)
+        trainer.fit(model)
+
+    :param str save_dir: Save directory
+    :param str name: Experiment name. Defaults to "default".
+    :param int version: Experiment version. If version is not specified the logger inspects the save
+        directory for existing versions, then automatically assigns the next available version.
+    :param \**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor.
+
+    """
+
+    def __init__(self, save_dir, name="default", version=None, **kwargs):
+        super().__init__()
+        self.save_dir = save_dir
+        self._name = name
+        # ``None`` means "resolve the next free version on first read of ``version``".
+        self._version = version
+
+        # The SummaryWriter is only created on first access of ``experiment``.
+        self._experiment = None
+        self.kwargs = kwargs
+
+    @property
+    def experiment(self):
+        """The underlying :class:`torch.utils.tensorboard.SummaryWriter`.
+
+        Created on first access and cached afterwards.
+
+        :rtype: torch.utils.tensorboard.SummaryWriter
+        """
+        if self._experiment is not None:
+            return self._experiment
+
+        root_dir = os.path.join(self.save_dir, self.name)
+        os.makedirs(root_dir, exist_ok=True)
+        log_dir = os.path.join(root_dir, str(self.version))
+        self._experiment = SummaryWriter(log_dir=log_dir, **self.kwargs)
+        return self._experiment
+
+    @rank_zero_only
+    def log_hyperparams(self, params):
+        """Record hyperparameters through :meth:`SummaryWriter.add_hparams`.
+
+        :param params: Hyperparameters, either a ``dict`` or an object exposing them as
+            attributes (anything :func:`vars` accepts, e.g. an ``argparse.Namespace``).
+        """
+        if parse_version(torch.__version__) < parse_version("1.3.0"):
+            warn(
+                f"Hyperparameter logging is not available for Torch version {torch.__version__}. "
+                "Skipping log_hyperparams. Upgrade to Torch 1.3.0 or above to enable "
+                "hyperparameter logging"
+            )
+            return
+
+        # ``vars()`` raises TypeError on a plain dict, so only use it for attribute containers.
+        hparams = dict(params) if isinstance(params, dict) else dict(vars(params))
+        # ``add_hparams`` rejects a missing metric dict; no metrics are known at this point.
+        self.experiment.add_hparams(hparam_dict=hparams, metric_dict={})
+
+    @rank_zero_only
+    def log_metrics(self, metrics, step_idx=None):
+        """Write every entry of ``metrics`` as a scalar.
+
+        :param dict metrics: Mapping of metric name to a number or a single-element tensor.
+        :param int step_idx: Step passed on to ``add_scalar`` for each value (optional).
+        """
+        for k, v in metrics.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            self.experiment.add_scalar(k, v, step_idx)
+
+    @rank_zero_only
+    def save(self):
+        """Flush the underlying writer."""
+        self.experiment.flush()
+
+    @rank_zero_only
+    def finalize(self, status):
+        """Flush the writer; ``status`` is not used."""
+        self.save()
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def version(self):
+        if self._version is None:
+            self._version = self._get_next_version()
+        return self._version
+
+    def _get_next_version(self):
+        """Return one more than the highest all-digit sub-directory of the experiment folder."""
+        root_dir = os.path.join(self.save_dir, self.name)
+        # ``version`` may be read before ``experiment`` has created the folder.
+        if not os.path.isdir(root_dir):
+            return 0
+        existing_versions = [
+            int(d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d)) and d.isdigit()
+        ]
+        return max(existing_versions, default=-1) + 1
diff --git a/tests/test_logging.py b/tests/test_logging.py
index bf35c2cc46..24d0c9182a 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -1,9 +1,16 @@
 import os
 import pickle
 
+import pytest
+import torch
+
 import tests.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.logging import LightningLoggerBase, rank_zero_only
+from pytorch_lightning.logging import (
+    LightningLoggerBase,
+    rank_zero_only,
+    TensorboardLogger,
+)
 from pytorch_lightning.testing import LightningTestModel
 
 
@@ -16,6 +23,7 @@ def test_testtube_logger(tmpdir):
     logger = tutils.get_test_tube_logger(tmpdir, False)
 
     trainer_options = dict(
+        default_save_path=tmpdir,
         max_epochs=1,
         train_percent_check=0.01,
         logger=logger
@@ -39,6 +47,7 @@ def test_testtube_pickle(tmpdir):
     logger.save()
 
     trainer_options = dict(
+        default_save_path=tmpdir,
         max_epochs=1,
         train_percent_check=0.01,
         logger=logger
@@ -66,6 +75,7 @@ def test_mlflow_logger(tmpdir):
     logger = MLFlowLogger("test", tracking_uri=f"file:{os.sep * 2}{mlflow_dir}")
 
     trainer_options = dict(
+        default_save_path=tmpdir,
         max_epochs=1,
         train_percent_check=0.01,
         logger=logger
@@ -92,6 +102,7 @@ def test_mlflow_pickle(tmpdir):
     mlflow_dir = os.path.join(tmpdir, "mlruns")
     logger = MLFlowLogger("test", tracking_uri=f"file:{os.sep * 2}{mlflow_dir}")
     trainer_options = dict(
+        default_save_path=tmpdir,
         max_epochs=1,
         logger=logger
     )
@@ -130,6 +141,7 @@ def test_comet_logger(tmpdir, monkeypatch):
     )
 
     trainer_options = dict(
+        default_save_path=tmpdir,
         max_epochs=1,
         train_percent_check=0.01,
         logger=logger
@@ -170,6 +182,7 @@ def test_comet_pickle(tmpdir, monkeypatch):
     )
 
     trainer_options = dict(
+        default_save_path=tmpdir,
         max_epochs=1,
         logger=logger
     )
@@ -180,6 +193,104 @@ def test_comet_pickle(tmpdir, monkeypatch):
     trainer2.logger.log_metrics({"acc": 1.0})
 
 
+def test_tensorboard_logger(tmpdir):
+    """Verify that basic functionality of Tensorboard logger works."""
+
+    hparams = tutils.get_hparams()
+    model = LightningTestModel(hparams)
+
+    logger = TensorboardLogger(save_dir=tmpdir, name="tensorboard_logger_test")
+
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        train_percent_check=0.01,
+        logger=logger
+    )
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    print("result finished")
+    assert result == 1, "Training failed"
+
+
+def test_tensorboard_pickle(tmpdir):
+    """Verify that pickling trainer with Tensorboard logger works."""
+
+    hparams = tutils.get_hparams()
+    model = LightningTestModel(hparams)
+
+    logger = TensorboardLogger(save_dir=tmpdir, name="tensorboard_pickle_test")
+
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        logger=logger
+    )
+
+    trainer = Trainer(**trainer_options)
+    pkl_bytes = pickle.dumps(trainer)
+    trainer2 = pickle.loads(pkl_bytes)
+    trainer2.logger.log_metrics({"acc": 1.0})
+
+
+def test_tensorboard_automatic_versioning(tmpdir):
+    """Verify that automatic versioning works"""
+
+    root_dir = tmpdir.mkdir("tb_versioning")
+    root_dir.mkdir("0")
+    root_dir.mkdir("1")
+
+    logger = TensorboardLogger(save_dir=tmpdir, name="tb_versioning")
+
+    assert logger.version == 2
+
+
+def test_tensorboard_automatic_versioning_new_dir(tmpdir):
+    """Verify that versioning starts at 0 when the experiment folder does not exist yet"""
+
+    logger = TensorboardLogger(save_dir=tmpdir, name="tb_versioning_new")
+
+    assert logger.version == 0
+
+
+def test_tensorboard_manual_versioning(tmpdir):
+    """Verify that manual versioning works"""
+
+    root_dir = tmpdir.mkdir("tb_versioning")
+    root_dir.mkdir("0")
+    root_dir.mkdir("1")
+    root_dir.mkdir("2")
+
+    logger = TensorboardLogger(save_dir=tmpdir, name="tb_versioning", version=1)
+
+    assert logger.version == 1
+
+
+@pytest.mark.parametrize("step_idx", [10, None])
+def test_tensorboard_log_metrics(tmpdir, step_idx):
+    logger = TensorboardLogger(tmpdir)
+    metrics = {
+        "float": 0.3,
+        "int": 1,
+        "FloatTensor": torch.tensor(0.1),
+        "IntTensor": torch.tensor(1)
+    }
+    logger.log_metrics(metrics, step_idx)
+
+
+def test_tensorboard_log_hyperparams(tmpdir):
+    logger = TensorboardLogger(tmpdir)
+    hparams = {
+        "float": 0.3,
+        "int": 1,
+        "string": "abc",
+        "bool": True
+    }
+    logger.log_hyperparams(hparams)
+
+
 def test_custom_logger(tmpdir):
     class CustomLogger(LightningLoggerBase):
         def __init__(self):