Flatten dataclass hyperparameters for logging (#18906)
Co-authored-by: jaswon <jason@jwon.xyz>
This commit is contained in:
parent
ed7cc27d57
commit
8d68607cef
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from argparse import Namespace
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import Any, Dict, Mapping, MutableMapping, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
@ -88,8 +89,11 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent
|
|||
result: Dict[str, Any] = {}
|
||||
for k, v in params.items():
|
||||
new_key = parent_key + delimiter + str(k) if parent_key else str(k)
|
||||
if isinstance(v, Namespace):
|
||||
if is_dataclass(v):
|
||||
v = asdict(v)
|
||||
elif isinstance(v, Namespace):
|
||||
v = vars(v)
|
||||
|
||||
if isinstance(v, MutableMapping):
|
||||
result = {**result, **_flatten_dict(v, parent_key=new_key, delimiter=delimiter)}
|
||||
else:
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -73,6 +74,21 @@ def test_flatten_dict():
|
|||
assert "a" not in params
|
||||
assert "b" not in params
|
||||
|
||||
# Test flattening of dataclass objects
|
||||
@dataclass
|
||||
class A:
|
||||
c: int
|
||||
d: int
|
||||
|
||||
@dataclass
|
||||
class B:
|
||||
a: A
|
||||
b: int
|
||||
|
||||
params = {"params": B(a=A(c=1, d=2), b=3), "param": 4}
|
||||
params = _flatten_dict(params)
|
||||
assert params == {"param": 4, "params/b": 3, "params/a/c": 1, "params/a/d": 2}
|
||||
|
||||
|
||||
def test_sanitize_callable_params():
|
||||
"""Callback function are not serializiable.
|
||||
|
|
Loading…
Reference in New Issue