Sanitize hparams that can't be json-serialized in `WandbLogger.log_hyperparameters()` (#19769)

awaelchli authored on 2024-04-14 06:01:58 -07:00, committed by GitHub
parent 67b270bd4d
commit ce90b3898a
5 changed files with 59 additions and 6 deletions

src/lightning/fabric/utilities/logger.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from argparse import Namespace
 from dataclasses import asdict, is_dataclass
 from typing import Any, Dict, Mapping, MutableMapping, Optional, Union
@@ -132,6 +132,23 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
     return params
 
 
+def _convert_json_serializable(params: Dict[str, Any]) -> Dict[str, Any]:
+    """Convert non-serializable objects in params to string."""
+    return {k: str(v) if not _is_json_serializable(v) else v for k, v in params.items()}
+
+
+def _is_json_serializable(value: Any) -> bool:
+    """Test whether a variable can be encoded as json."""
+    if value is None or isinstance(value, (bool, int, float, str, list, dict)):  # fast path
+        return True
+    try:
+        json.dumps(value)
+        return True
+    except (TypeError, OverflowError):
+        # OverflowError is raised if number is too large to encode
+        return False
+
+
 def _add_prefix(
     metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str
 ) -> Mapping[str, Union[Tensor, float]]:
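For orientation, a short sketch of the behavior these helpers add. This is illustrative only and not part of the commit; `_convert_json_serializable` is a private utility, imported here the same way the new test below imports it:

from pathlib import Path

import torch

from lightning.fabric.utilities.logger import _convert_json_serializable

# JSON-serializable values pass through untouched; anything else becomes str(value).
params = {"lr": 0.1, "data_dir": Path("data"), "threshold": torch.tensor(1)}
print(_convert_json_serializable(params))
# {'lr': 0.1, 'data_dir': 'data', 'threshold': 'tensor(1)'}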

src/lightning/pytorch/CHANGELOG.md

@@ -49,7 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))
 
 -
 
+- Fixed `WandbLogger.log_hyperparameters()` raising an error if hyperparameters are not JSON serializable ([#19769](https://github.com/Lightning-AI/pytorch-lightning/pull/19769))
 
 -

src/lightning/pytorch/loggers/wandb.py

@@ -26,7 +26,12 @@ from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from typing_extensions import override
 
-from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
+from lightning.fabric.utilities.logger import (
+    _add_prefix,
+    _convert_json_serializable,
+    _convert_params,
+    _sanitize_callable_params,
+)
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
 from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
@@ -419,6 +424,7 @@ class WandbLogger(Logger):
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         params = _convert_params(params)
         params = _sanitize_callable_params(params)
+        params = _convert_json_serializable(params)
         self.experiment.config.update(params, allow_val_change=True)
 
     @override
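At the public API level, the net effect: `WandbLogger.log_hyperparams()` no longer errors out when a hyperparameter cannot be JSON-encoded; the value reaches wandb as its `str()` form instead. A minimal sketch of the fixed call, assuming `wandb` is installed and logged in (the project name is a placeholder):

from pathlib import Path

from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(project="demo")  # "demo" is a placeholder project name
# Before this fix, the Path value made the underlying config update raise an error;
# it is now sent to wandb as the string "data".
logger.log_hyperparams({"lr": 0.1, "data_dir": Path("data")})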

tests/tests_fabric/utilities/test_logger.py

@@ -11,14 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from argparse import Namespace
 from dataclasses import dataclass
 from pathlib import Path
 
 import numpy as np
 import torch
 
 from lightning.fabric.utilities.logger import (
     _add_prefix,
+    _convert_json_serializable,
     _convert_params,
     _flatten_dict,
     _sanitize_callable_params,
@@ -167,3 +168,29 @@ def test_add_prefix():
     assert "prefix-metric2" not in metrics
     assert metrics["prefix2_prefix-metric1"] == 1
     assert metrics["prefix2_prefix-metric2"] == 2
+
+
+def test_convert_json_serializable():
+    data = {
+        # JSON-serializable
+        "none": None,
+        "int": 1,
+        "float": 1.1,
+        "bool": True,
+        "dict": {"a": 1},
+        "list": [2, 3, 4],
+        # not JSON-serializable
+        "path": Path("path"),
+        "tensor": torch.tensor(1),
+    }
+    expected = {
+        "none": None,
+        "int": 1,
+        "float": 1.1,
+        "bool": True,
+        "dict": {"a": 1},
+        "list": [2, 3, 4],
+        "path": "path",
+        "tensor": "tensor(1)",
+    }
+    assert _convert_json_serializable(data) == expected
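One caveat worth flagging, and one the test above only exercises at the top level: the fast path in `_is_json_serializable` treats any `list` or `dict` as serializable without inspecting its elements, so a non-serializable value nested inside a container is passed through unconverted. An illustrative sketch, not part of the commit:

from pathlib import Path

from lightning.fabric.utilities.logger import _convert_json_serializable

# The outer value is a dict, so the fast path reports it as JSON-serializable
# and the nested Path is left as-is rather than stringified.
nested = {"paths": {"data_dir": Path("data")}}
assert _convert_json_serializable(nested) == nested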

tests/tests_pytorch/loggers/test_wandb.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import pickle
+from pathlib import Path
 from unittest import mock
 
 import pytest
@@ -113,9 +114,10 @@ def test_wandb_logger_init(wandb_mock):
     wandb_mock.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6})
 
     # log hyper parameters
-    hparams = {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]}
+    hparams = {"none": None, "dict": {"a": 1}, "b": [2, 3, 4], "path": Path("path")}
+    expected = {"none": None, "dict": {"a": 1}, "b": [2, 3, 4], "path": "path"}
     logger.log_hyperparams(hparams)
-    wandb_mock.init().config.update.assert_called_once_with(hparams, allow_val_change=True)
+    wandb_mock.init().config.update.assert_called_once_with(expected, allow_val_change=True)
 
     # watch a model
     logger.watch("model", "log", 10, False)