Sanitize hparams that can't be json-serialized in `WandbLogger.log_hyperparameters()` (#19769)
This commit is contained in:
parent
67b270bd4d
commit
ce90b3898a
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from argparse import Namespace
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import Any, Dict, Mapping, MutableMapping, Optional, Union
|
||||
|
@ -132,6 +132,23 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
|||
return params
|
||||
|
||||
|
||||
def _convert_json_serializable(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert non-serializable objects in params to string."""
|
||||
return {k: str(v) if not _is_json_serializable(v) else v for k, v in params.items()}
|
||||
|
||||
|
||||
def _is_json_serializable(value: Any) -> bool:
|
||||
"""Test whether a variable can be encoded as json."""
|
||||
if value is None or isinstance(value, (bool, int, float, str, list, dict)): # fast path
|
||||
return True
|
||||
try:
|
||||
json.dumps(value)
|
||||
return True
|
||||
except (TypeError, OverflowError):
|
||||
# OverflowError is raised if number is too large to encode
|
||||
return False
|
||||
|
||||
|
||||
def _add_prefix(
|
||||
metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str
|
||||
) -> Mapping[str, Union[Tensor, float]]:
|
||||
|
|
|
@ -49,7 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))
|
||||
|
||||
-
|
||||
|
||||
- Fixed `WandbLogger.log_hyperparameters()` raising an error if hyperparameters are not JSON serializable ([#19769](https://github.com/Lightning-AI/pytorch-lightning/pull/19769))
|
||||
|
||||
-
|
||||
|
||||
|
|
|
@ -26,7 +26,12 @@ from lightning_utilities.core.imports import RequirementCache
|
|||
from torch import Tensor
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
|
||||
from lightning.fabric.utilities.logger import (
|
||||
_add_prefix,
|
||||
_convert_json_serializable,
|
||||
_convert_params,
|
||||
_sanitize_callable_params,
|
||||
)
|
||||
from lightning.fabric.utilities.types import _PATH
|
||||
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
|
||||
|
@ -419,6 +424,7 @@ class WandbLogger(Logger):
|
|||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
|
||||
params = _convert_params(params)
|
||||
params = _sanitize_callable_params(params)
|
||||
params = _convert_json_serializable(params)
|
||||
self.experiment.config.update(params, allow_val_change=True)
|
||||
|
||||
@override
|
||||
|
|
|
@ -11,14 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lightning.fabric.utilities.logger import (
|
||||
_add_prefix,
|
||||
_convert_json_serializable,
|
||||
_convert_params,
|
||||
_flatten_dict,
|
||||
_sanitize_callable_params,
|
||||
|
@ -167,3 +168,29 @@ def test_add_prefix():
|
|||
assert "prefix-metric2" not in metrics
|
||||
assert metrics["prefix2_prefix-metric1"] == 1
|
||||
assert metrics["prefix2_prefix-metric2"] == 2
|
||||
|
||||
|
||||
def test_convert_json_serializable():
|
||||
data = {
|
||||
# JSON-serializable
|
||||
"none": None,
|
||||
"int": 1,
|
||||
"float": 1.1,
|
||||
"bool": True,
|
||||
"dict": {"a": 1},
|
||||
"list": [2, 3, 4],
|
||||
# not JSON-serializable
|
||||
"path": Path("path"),
|
||||
"tensor": torch.tensor(1),
|
||||
}
|
||||
expected = {
|
||||
"none": None,
|
||||
"int": 1,
|
||||
"float": 1.1,
|
||||
"bool": True,
|
||||
"dict": {"a": 1},
|
||||
"list": [2, 3, 4],
|
||||
"path": "path",
|
||||
"tensor": "tensor(1)",
|
||||
}
|
||||
assert _convert_json_serializable(data) == expected
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
@ -113,9 +114,10 @@ def test_wandb_logger_init(wandb_mock):
|
|||
wandb_mock.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6})
|
||||
|
||||
# log hyper parameters
|
||||
hparams = {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]}
|
||||
hparams = {"none": None, "dict": {"a": 1}, "b": [2, 3, 4], "path": Path("path")}
|
||||
expected = {"none": None, "dict": {"a": 1}, "b": [2, 3, 4], "path": "path"}
|
||||
logger.log_hyperparams(hparams)
|
||||
wandb_mock.init().config.update.assert_called_once_with(hparams, allow_val_change=True)
|
||||
wandb_mock.init().config.update.assert_called_once_with(expected, allow_val_change=True)
|
||||
|
||||
# watch a model
|
||||
logger.watch("model", "log", 10, False)
|
||||
|
|
Loading…
Reference in New Issue