sanitize arrays when logging as hyperparameters in TensorBoardLogger (#9031)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
parent
1feec8c601
commit
dfae7342cc
|
@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974))
|
||||
|
||||
|
||||
- Added sanitization of tensors when they get logged as hyperparameters in `TensorBoardLogger` ([#9031](https://github.com/PyTorchLightning/pytorch-lightning/pull/9031))
|
||||
|
||||
|
||||
- Added `InterBatchParallelDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))
|
||||
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import os
|
|||
from argparse import Namespace
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.utils.tensorboard.summary import hparams
|
||||
|
@ -286,6 +287,12 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
|
||||
return max(existing_versions) + 1
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
params = LightningLoggerBase._sanitize_params(params)
|
||||
# logging of arrays with dimension > 1 is not supported, sanitize as string
|
||||
return {k: str(v) if isinstance(v, (torch.Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()}
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["_experiment"] = None
|
||||
|
|
|
@ -17,6 +17,7 @@ import os
|
|||
from argparse import Namespace
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
|
@ -178,6 +179,8 @@ def test_tensorboard_log_hyperparams(tmpdir):
|
|||
"list": [1, 2, 3],
|
||||
"namespace": Namespace(foo=Namespace(bar="buzz")),
|
||||
"layer": torch.nn.BatchNorm1d,
|
||||
"tensor": torch.empty(2, 2, 2),
|
||||
"array": np.empty([2, 2, 2]),
|
||||
}
|
||||
logger.log_hyperparams(hparams)
|
||||
|
||||
|
@ -193,6 +196,8 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
|
|||
"list": [1, 2, 3],
|
||||
"namespace": Namespace(foo=Namespace(bar="buzz")),
|
||||
"layer": torch.nn.BatchNorm1d,
|
||||
"tensor": torch.empty(2, 2, 2),
|
||||
"array": np.empty([2, 2, 2]),
|
||||
}
|
||||
metrics = {"abc": torch.tensor([0.54])}
|
||||
logger.log_hyperparams(hparams, metrics)
|
||||
|
|
Loading…
Reference in New Issue