# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from argparse import Namespace
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
import torch
|
|
from lightning.fabric.utilities.logger import (
|
|
_add_prefix,
|
|
_convert_params,
|
|
_flatten_dict,
|
|
_sanitize_callable_params,
|
|
_sanitize_params,
|
|
)
|
|
|
|
|
|
def test_convert_params():
    """Test conversion of params to a dict.

    Covers the three input kinds exercised here: a plain dict, ``None`` and an ``argparse.Namespace``.
    """
    # A plain dict must come back with identical contents
    params = {"string": "string", "int": 1, "float": 0.1, "bool": True, "none": None}
    expected = params.copy()
    assert _convert_params(params) == expected

    # ``None`` is converted to an empty dict
    assert _convert_params(None) == {}

    # An argparse Namespace is converted to a dict of its attributes. The expected value is written out
    # explicitly instead of using ``vars(params)``: ``vars`` returns the namespace's own ``__dict__``, so
    # comparing against it would be self-referential and could not catch a conversion that returns or
    # mutates that same object incorrectly.
    params = Namespace(string="string", int=1, float=0.1, bool=True, none=None)
    expected = {"string": "string", "int": 1, "float": 0.1, "bool": True, "none": None}
    assert _convert_params(params) == expected
|
|
|
|
|
|
def test_flatten_dict():
    """Validate flatten_dict can handle nested dictionaries, argparse Namespace and dataclass objects."""
    # Basic dict flattening with a custom delimiter. The whole result is compared so that leftover
    # (unflattened) or unexpected keys are caught too, not just the presence of the expected ones.
    params = {"a": {"b": "c"}}
    params = _flatten_dict(params, "--")
    assert params == {"a--b": "c"}

    # Complex nested dict flattening: nested levels use integer keys, which appear in the joined
    # string keys, and leaf values such as lists are kept as-is
    params = {"a": {5: {"foo": "bar"}}, "b": 6, "c": {7: [1, 2, 3, 4], 8: "foo", 9: {10: "bar"}}}
    params = _flatten_dict(params)
    assert params == {"a/5/foo": "bar", "b": 6, "c/7": [1, 2, 3, 4], "c/8": "foo", "c/9/10": "bar"}

    # argparse Namespace values are expanded like nested dicts
    params = Namespace(a=1, b=2)
    wrapping_dict = {"params": params}
    params = _flatten_dict(wrapping_dict)

    params_type = type(params)  # way around needed for Ruff's `isinstance` suggestion
    assert params_type is dict  # exact type: a dict subclass would still pass the equality check below
    assert params == {"params/a": 1, "params/b": 2}

    # Dataclass values, including nested dataclasses, are expanded like nested dicts
    @dataclass
    class A:
        c: int
        d: int

    @dataclass
    class B:
        a: A
        b: int

    params = {"params": B(a=A(c=1, d=2), b=3), "param": 4}
    params = _flatten_dict(params)
    assert params == {"param": 4, "params/b": 3, "params/a/c": 1, "params/a/d": 2}
|
|
|
|
|
|
def test_sanitize_callable_params():
    """Callable params are not serializable, so they must be sanitized before logging.

    Each callable is given one chance to produce a loggable value: it is called once, and a non-callable
    result replaces it. If the call returns yet another callable, the original callable's ``__name__`` is
    logged instead (``"<lambda>"`` for anonymous functions). Non-callable values are left untouched.
    """

    def return_something():
        return "something"

    def wrapper_something():
        return return_something

    params = Namespace(
        foo="bar",
        something=return_something,
        wrapper_something_wo_name=(lambda: lambda: "1"),
        wrapper_something=wrapper_something,
    )

    # Run the params through conversion, flattening and callable sanitization in sequence
    params = _convert_params(params)
    params = _flatten_dict(params)
    params = _sanitize_callable_params(params)
    assert params["foo"] == "bar"  # plain values pass through unchanged
    assert params["something"] == "something"  # called once; its non-callable result is kept
    # Calling these returns another callable, so the outer callable's name is logged instead
    assert params["wrapper_something"] == "wrapper_something"
    assert params["wrapper_something_wo_name"] == "<lambda>"
|
|
|
|
|
|
def test_sanitize_params():
    """Check that ``_sanitize_params`` maps every supported value type to a loggable value."""
    raw = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "list": [1, 2, 3],
        "np_bool": np.bool_(False),
        "np_int": np.int_(5),
        "np_double": np.double(3.14159),
        "namespace": Namespace(foo=3),
        "layer": torch.nn.BatchNorm1d,
        "tensor": torch.ones(3),
    }
    sanitized = _sanitize_params(raw)

    # Booleans are checked by identity, so the NumPy bool must come back as the Python ``False`` singleton
    assert sanitized["bool"] is True
    assert sanitized["np_bool"] is False

    # Remaining scalars keep their value
    scalar_expectations = (
        ("float", 0.3),
        ("int", 1),
        ("string", "abc"),
        ("np_int", 5),
        ("np_double", 3.14159),
    )
    for key, value in scalar_expectations:
        assert sanitized[key] == value, key

    # Lists, namespaces and classes are rendered as their string representation
    assert sanitized["list"] == "[1, 2, 3]"
    assert sanitized["namespace"] == "Namespace(foo=3)"
    assert sanitized["layer"] == "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>"

    # Tensors keep their values
    assert torch.equal(sanitized["tensor"], torch.ones(3))
|
|
|
|
|
|
def test_add_prefix():
    """Check that ``_add_prefix`` rewrites keys as ``<prefix><separator><key>`` and that prefixes stack."""
    prefixed = _add_prefix({"metric1": 1, "metric2": 2}, "prefix", "-")

    # Every original key is replaced by its prefixed form
    assert prefixed.keys() >= {"prefix-metric1", "prefix-metric2"}
    assert prefixed.keys().isdisjoint({"metric1", "metric2"})

    # Prefixing again wraps the previous prefix with the new one and keeps the values
    double_prefixed = _add_prefix(prefixed, "prefix2", "_")

    assert double_prefixed.keys() >= {"prefix2_prefix-metric1", "prefix2_prefix-metric2"}
    assert double_prefixed.keys().isdisjoint({"prefix-metric1", "prefix-metric2"})
    assert double_prefixed["prefix2_prefix-metric1"] == 1
    assert double_prefixed["prefix2_prefix-metric2"] == 2
|