diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py index c3874262ca..2604a0d926 100644 --- a/src/lightning/fabric/utilities/logger.py +++ b/src/lightning/fabric/utilities/logger.py @@ -13,6 +13,7 @@ # limitations under the License. from argparse import Namespace +from dataclasses import asdict, is_dataclass from typing import Any, Dict, Mapping, MutableMapping, Optional, Union import numpy as np @@ -88,8 +89,11 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent result: Dict[str, Any] = {} for k, v in params.items(): new_key = parent_key + delimiter + str(k) if parent_key else str(k) - if isinstance(v, Namespace): + if is_dataclass(v): + v = asdict(v) + elif isinstance(v, Namespace): v = vars(v) + if isinstance(v, MutableMapping): result = {**result, **_flatten_dict(v, parent_key=new_key, delimiter=delimiter)} else: diff --git a/tests/tests_fabric/utilities/test_logger.py b/tests/tests_fabric/utilities/test_logger.py index b5e28c9b09..5b62113314 100644 --- a/tests/tests_fabric/utilities/test_logger.py +++ b/tests/tests_fabric/utilities/test_logger.py @@ -13,6 +13,7 @@ # limitations under the License. from argparse import Namespace +from dataclasses import dataclass import numpy as np import torch @@ -73,6 +74,21 @@ def test_flatten_dict(): assert "a" not in params assert "b" not in params + # Test flattening of dataclass objects + @dataclass + class A: + c: int + d: int + + @dataclass + class B: + a: A + b: int + + params = {"params": B(a=A(c=1, d=2), b=3), "param": 4} + params = _flatten_dict(params) + assert params == {"param": 4, "params/b": 3, "params/a/c": 1, "params/a/d": 2} + def test_sanitize_callable_params(): """Callback function are not serializiable.