From dfae7342ccfeeb5317090caf80714af0ef07acc9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 24 Aug 2021 13:02:06 +0200
Subject: [PATCH] sanitize arrays when logging as hyperparameters in
 TensorBoardLogger (#9031)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton
---
 CHANGELOG.md                             | 3 +++
 pytorch_lightning/loggers/tensorboard.py | 7 +++++++
 tests/loggers/test_tensorboard.py        | 5 +++++
 3 files changed, 15 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f0b3d5f125..d7fe1638ca 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974))
 
 
+- Added sanitization of tensors when they get logged as hyperparameters in `TensorBoardLogger` ([#9031](https://github.com/PyTorchLightning/pytorch-lightning/pull/9031))
+
+
 - Added `InterBatchParallelDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))
 
 
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index 85ecb51304..e51e6b1f6a 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -21,6 +21,7 @@ import os
 from argparse import Namespace
 from typing import Any, Dict, Optional, Union
 
+import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 from torch.utils.tensorboard.summary import hparams
@@ -286,6 +287,12 @@ class TensorBoardLogger(LightningLoggerBase):
 
         return max(existing_versions) + 1
 
+    @staticmethod
+    def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
+        params = LightningLoggerBase._sanitize_params(params)
+        # logging of arrays with dimension > 1 is not supported, sanitize as string
+        return {k: str(v) if isinstance(v, (torch.Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()}
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["_experiment"] = None
diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py
index ef6d2203a3..a1c66c0559 100644
--- a/tests/loggers/test_tensorboard.py
+++ b/tests/loggers/test_tensorboard.py
@@ -17,6 +17,7 @@ import os
 from argparse import Namespace
 from unittest import mock
 
+import numpy as np
 import pytest
 import torch
 import yaml
@@ -178,6 +179,8 @@ def test_tensorboard_log_hyperparams(tmpdir):
         "list": [1, 2, 3],
         "namespace": Namespace(foo=Namespace(bar="buzz")),
         "layer": torch.nn.BatchNorm1d,
+        "tensor": torch.empty(2, 2, 2),
+        "array": np.empty([2, 2, 2]),
     }
 
     logger.log_hyperparams(hparams)
@@ -193,6 +196,8 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
         "list": [1, 2, 3],
         "namespace": Namespace(foo=Namespace(bar="buzz")),
         "layer": torch.nn.BatchNorm1d,
+        "tensor": torch.empty(2, 2, 2),
+        "array": np.empty([2, 2, 2]),
     }
     metrics = {"abc": torch.tensor([0.54])}
     logger.log_hyperparams(hparams, metrics)
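
Illustrative usage (not part of the patch): a minimal sketch of the behaviour this change enables, assuming a pytorch_lightning build that includes it; the "logs" directory name and the hyperparameter keys are arbitrary.

import numpy as np
import torch

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="logs")

# Logging of arrays with more than one dimension is not supported by the
# TensorBoard hparams plugin; with this patch such values are sanitized to
# their string representation before being written to the hparams summary.
logger.log_hyperparams(
    {
        "float": 0.3,                    # scalar: logged as-is
        "tensor": torch.empty(2, 2, 2),  # ndim > 1: sanitized to str(tensor)
        "array": np.empty([2, 2, 2]),    # ndim > 1: sanitized to str(array)
    }
)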