Flatten dataclass hyperparameters for logging (#18906)

Co-authored-by: jaswon <jason@jwon.xyz>
This commit is contained in:
Jason Won 2023-11-03 19:30:19 -04:00 committed by GitHub
parent ed7cc27d57
commit 8d68607cef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 1 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License.
from argparse import Namespace
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Mapping, MutableMapping, Optional, Union
import numpy as np
@ -88,8 +89,11 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent
result: Dict[str, Any] = {}
for k, v in params.items():
new_key = parent_key + delimiter + str(k) if parent_key else str(k)
if isinstance(v, Namespace):
if is_dataclass(v):
v = asdict(v)
elif isinstance(v, Namespace):
v = vars(v)
if isinstance(v, MutableMapping):
result = {**result, **_flatten_dict(v, parent_key=new_key, delimiter=delimiter)}
else:

View File

@ -13,6 +13,7 @@
# limitations under the License.
from argparse import Namespace
from dataclasses import dataclass
import numpy as np
import torch
@ -73,6 +74,21 @@ def test_flatten_dict():
assert "a" not in params
assert "b" not in params
# Test flattening of dataclass objects
@dataclass
class A:
c: int
d: int
@dataclass
class B:
a: A
b: int
params = {"params": B(a=A(c=1, d=2), b=3), "param": 4}
params = _flatten_dict(params)
assert params == {"param": 4, "params/b": 3, "params/a/c": 1, "params/a/d": 2}
def test_sanitize_callable_params():
"""Callback function are not serializiable.