sanitize arrays when logging as hyperparameters in TensorBoardLogger (#9031)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
Adrian Wälchli 2021-08-24 13:02:06 +02:00 committed by GitHub
parent 1feec8c601
commit dfae7342cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 0 deletions

View File

@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974))
- Added sanitization of tensors when they get logged as hyperparameters in `TensorBoardLogger` ([#9031](https://github.com/PyTorchLightning/pytorch-lightning/pull/9031))
- Added `InterBatchParallelDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))

View File

@ -21,6 +21,7 @@ import os
from argparse import Namespace
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams
@ -286,6 +287,12 @@ class TensorBoardLogger(LightningLoggerBase):
return max(existing_versions) + 1
@staticmethod
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
params = LightningLoggerBase._sanitize_params(params)
# logging of arrays with dimension > 1 is not supported, sanitize as string
return {k: str(v) if isinstance(v, (torch.Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()}
def __getstate__(self):
state = self.__dict__.copy()
state["_experiment"] = None

View File

@ -17,6 +17,7 @@ import os
from argparse import Namespace
from unittest import mock
import numpy as np
import pytest
import torch
import yaml
@ -178,6 +179,8 @@ def test_tensorboard_log_hyperparams(tmpdir):
"list": [1, 2, 3],
"namespace": Namespace(foo=Namespace(bar="buzz")),
"layer": torch.nn.BatchNorm1d,
"tensor": torch.empty(2, 2, 2),
"array": np.empty([2, 2, 2]),
}
logger.log_hyperparams(hparams)
@ -193,6 +196,8 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
"list": [1, 2, 3],
"namespace": Namespace(foo=Namespace(bar="buzz")),
"layer": torch.nn.BatchNorm1d,
"tensor": torch.empty(2, 2, 2),
"array": np.empty([2, 2, 2]),
}
metrics = {"abc": torch.tensor([0.54])}
logger.log_hyperparams(hparams, metrics)