From ce90b3898aff1dac89215f49f40b19777c91125a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 14 Apr 2024 06:01:58 -0700 Subject: [PATCH] Sanitize hparams that can't be json-serialized in `WandbLogger.log_hyperparameters()` (#19769) --- src/lightning/fabric/utilities/logger.py | 19 +++++++++++++- src/lightning/pytorch/CHANGELOG.md | 3 ++- src/lightning/pytorch/loggers/wandb.py | 8 +++++- tests/tests_fabric/utilities/test_logger.py | 29 ++++++++++++++++++++- tests/tests_pytorch/loggers/test_wandb.py | 6 +++-- 5 files changed, 59 insertions(+), 6 deletions(-) diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py index 2604a0d926..abe5816ded 100644 --- a/src/lightning/fabric/utilities/logger.py +++ b/src/lightning/fabric/utilities/logger.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json from argparse import Namespace from dataclasses import asdict, is_dataclass from typing import Any, Dict, Mapping, MutableMapping, Optional, Union @@ -132,6 +132,23 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: return params +def _convert_json_serializable(params: Dict[str, Any]) -> Dict[str, Any]: + """Convert non-serializable objects in params to string.""" + return {k: str(v) if not _is_json_serializable(v) else v for k, v in params.items()} + + +def _is_json_serializable(value: Any) -> bool: + """Test whether a variable can be encoded as json.""" + if value is None or isinstance(value, (bool, int, float, str, list, dict)): # fast path + return True + try: + json.dumps(value) + return True + except (TypeError, OverflowError): + # OverflowError is raised if number is too large to encode + return False + + def _add_prefix( metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str ) -> Mapping[str, Union[Tensor, float]]: diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index be7b66a27c..9f2a1b40ac 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -49,7 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627)) -- + +- Fixed `WandbLogger.log_hyperparameters()` raising an error if hyperparameters are not JSON serializable ([#19769](https://github.com/Lightning-AI/pytorch-lightning/pull/19769)) - diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 4025f2cd18..c5d995bff3 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -26,7 +26,12 @@ from lightning_utilities.core.imports import RequirementCache from torch import Tensor from typing_extensions import override -from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params +from lightning.fabric.utilities.logger import ( + _add_prefix, + _convert_json_serializable, + _convert_params, + _sanitize_callable_params, +) from lightning.fabric.utilities.types import _PATH from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment @@ -419,6 +424,7 @@ class WandbLogger(Logger): def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _sanitize_callable_params(params) + params = _convert_json_serializable(params) self.experiment.config.update(params, allow_val_change=True) @override diff --git a/tests/tests_fabric/utilities/test_logger.py b/tests/tests_fabric/utilities/test_logger.py index 5b62113314..33681c65f7 100644 --- a/tests/tests_fabric/utilities/test_logger.py +++ b/tests/tests_fabric/utilities/test_logger.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from argparse import Namespace from dataclasses import dataclass +from pathlib import Path import numpy as np import torch from lightning.fabric.utilities.logger import ( _add_prefix, + _convert_json_serializable, _convert_params, _flatten_dict, _sanitize_callable_params, @@ -167,3 +168,29 @@ def test_add_prefix(): assert "prefix-metric2" not in metrics assert metrics["prefix2_prefix-metric1"] == 1 assert metrics["prefix2_prefix-metric2"] == 2 + + +def test_convert_json_serializable(): + data = { + # JSON-serializable + "none": None, + "int": 1, + "float": 1.1, + "bool": True, + "dict": {"a": 1}, + "list": [2, 3, 4], + # not JSON-serializable + "path": Path("path"), + "tensor": torch.tensor(1), + } + expected = { + "none": None, + "int": 1, + "float": 1.1, + "bool": True, + "dict": {"a": 1}, + "list": [2, 3, 4], + "path": "path", + "tensor": "tensor(1)", + } + assert _convert_json_serializable(data) == expected diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index f667b0a7b5..a8e70bfb65 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -13,6 +13,7 @@ # limitations under the License. import os import pickle +from pathlib import Path from unittest import mock import pytest @@ -113,9 +114,10 @@ def test_wandb_logger_init(wandb_mock): wandb_mock.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6}) # log hyper parameters - hparams = {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]} + hparams = {"none": None, "dict": {"a": 1}, "b": [2, 3, 4], "path": Path("path")} + expected = {"none": None, "dict": {"a": 1}, "b": [2, 3, 4], "path": "path"} logger.log_hyperparams(hparams) - wandb_mock.init().config.update.assert_called_once_with(hparams, allow_val_change=True) + wandb_mock.init().config.update.assert_called_once_with(expected, allow_val_change=True) # watch a model logger.watch("model", "log", 10, False)