# lightning/tests/tests_pytorch/utilities/test_logger.py
#
# 194 lines
# 6.0 KiB
# Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser, Namespace
import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.logger import (
_add_prefix,
_convert_params,
_flatten_dict,
_sanitize_callable_params,
_sanitize_params,
_version,
)
def test_convert_params():
    """Check that ``_convert_params`` returns a plain ``dict`` for each supported input.

    Three inputs are exercised: an ordinary dict, ``None`` and an argparse
    ``Namespace`` produced by the Trainer's argument parser.
    """
    # An ordinary dict is expected to come back as a dict with the same entries.
    plain = {"foo": "bar", 1: 23}
    assert type(plain) is dict
    converted = _convert_params(plain)
    assert type(converted) is dict
    assert converted["foo"] == "bar"
    assert converted[1] == 23

    # ``None`` is expected to turn into an empty dict.
    missing = None
    assert type(missing) is not dict
    converted = _convert_params(missing)
    assert type(converted) is dict
    assert converted == {}

    # A parsed ``Namespace`` is expected to turn into a dict keyed by argument name.
    cli_args = "--max_epochs 1".split()
    trainer_parser = Trainer.add_argparse_args(parent_parser=ArgumentParser())
    namespace = trainer_parser.parse_args(cli_args)
    assert type(namespace) is Namespace
    converted = _convert_params(namespace)
    assert type(converted) is dict
    assert converted["gpus"] is None
def test_flatten_dict():
    """Exercise ``_flatten_dict`` on nested dicts and on a dict wrapping an argparse ``Namespace``."""
    # One level of nesting, joined with a caller-supplied delimiter.
    flat = _flatten_dict({"a": {"b": "c"}}, "--")
    assert "a" not in flat
    assert flat["a--b"] == "c"

    # Mixed-depth nesting with non-string keys; no delimiter is passed here and
    # the expected keys are joined with "/".
    nested = {
        "a": {5: {"foo": "bar"}},
        "b": 6,
        "c": {7: [1, 2, 3, 4], 8: "foo", 9: {10: "bar"}},
    }
    flat = _flatten_dict(nested)
    assert "a" not in flat
    expected_entries = (
        ("a/5/foo", "bar"),
        ("b", 6),
        ("c/7", [1, 2, 3, 4]),
        ("c/8", "foo"),
        ("c/9/10", "bar"),
    )
    for flat_key, expected_value in expected_entries:
        assert flat[flat_key] == expected_value

    # A Namespace stored under a dict key is expected to be flattened like a nested dict.
    trainer_parser = Trainer.add_argparse_args(parent_parser=ArgumentParser())
    namespace = trainer_parser.parse_args("--max_epochs 1".split())
    flat = _flatten_dict({"params": namespace})
    assert type(flat) is dict
    assert flat["params/logger"] is True
    assert flat["params/gpus"] is None
    # The un-prefixed argument names must not appear at the top level.
    for bare_key in ("logger", "gpus"):
        assert bare_key not in flat
def test_sanitize_callable_params():
    """Check how ``_sanitize_callable_params`` replaces callable hyperparameters.

    Callables are not serializable. The assertions below expect a callable
    that returns a string to be replaced by that string, and a callable that
    returns another callable to be replaced by its own ``__name__``.
    """

    # NOTE: the names of these helpers matter; the assertions compare against ``__name__``.
    def return_something():
        return "something"

    def wrapper_something():
        return return_something

    trainer_parser = Trainer.add_argparse_args(parent_parser=ArgumentParser())
    namespace = trainer_parser.parse_args("--max_epochs 1".split())

    # Attach the callables as if they were user-provided hyperparameters.
    namespace.something = return_something
    namespace.wrapper_something_wo_name = lambda: lambda: "1"
    namespace.wrapper_something = wrapper_something

    # Same pipeline as the original test: convert -> flatten -> sanitize.
    sanitized = _sanitize_callable_params(_flatten_dict(_convert_params(namespace)))

    assert sanitized["gpus"] is None
    assert sanitized["something"] == "something"
    assert sanitized["wrapper_something"] == "wrapper_something"
    assert sanitized["wrapper_something_wo_name"] == "<lambda>"
def test_sanitize_params():
    """Check the values ``_sanitize_params`` produces for a mix of Python, NumPy and torch inputs."""
    raw = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "list": [1, 2, 3],
        "np_bool": np.bool_(False),
        "np_int": np.int_(5),
        "np_double": np.double(3.14159),
        "namespace": Namespace(foo=3),
        "layer": torch.nn.BatchNorm1d,
        "tensor": torch.ones(3),
    }
    sanitized = _sanitize_params(raw)

    # Built-in scalars are expected to compare equal to what went in.
    for key, expected in (("float", 0.3), ("int", 1), ("string", "abc")):
        assert sanitized[key] == expected
    assert sanitized["bool"] is True

    # A list is expected to come back as its string representation.
    assert sanitized["list"] == "[1, 2, 3]"

    # The identity check only passes if the NumPy bool became the built-in ``False``.
    assert sanitized["np_bool"] is False

    # NumPy numbers keep their value; Namespace and class objects become strings.
    for key, expected in (
        ("np_int", 5),
        ("np_double", 3.14159),
        ("namespace", "Namespace(foo=3)"),
        ("layer", "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>"),
    ):
        assert sanitized[key] == expected

    # The tensor entry must still hold the same values.
    assert torch.equal(sanitized["tensor"], torch.ones(3))
def test_add_prefix():
    """Check that ``_add_prefix`` renames every key to ``<prefix><separator><key>`` and keeps the values."""
    original_names = ("metric1", "metric2")

    # First pass: prefix joined with "-".
    prefixed = _add_prefix({"metric1": 1, "metric2": 2}, "prefix", "-")
    for name in original_names:
        assert f"prefix-{name}" in prefixed
    for name in original_names:
        assert name not in prefixed

    # Second pass on the already-prefixed dict, this time joined with "_".
    doubly_prefixed = _add_prefix(prefixed, "prefix2", "_")
    for name in original_names:
        assert f"prefix2_prefix-{name}" in doubly_prefixed
    for name in original_names:
        assert f"prefix-{name}" not in doubly_prefixed

    # Values are expected to be unchanged by the renaming.
    assert doubly_prefixed["prefix2_prefix-metric1"] == 1
    assert doubly_prefixed["prefix2_prefix-metric2"] == 2
def test_version(tmpdir):
    """Check how ``_version`` combines the versions of a list of loggers.

    Expected results: an empty string for no loggers, the bare version for a
    single logger, and the versions joined by a separator for several loggers,
    with the repeated version ``0`` appearing only once.
    """
    # Four CSV loggers in the same directory; version 0 is deliberately repeated.
    loggers = [CSVLogger(tmpdir, version=number) for number in (0, 2, 1, 0)]
    version_one_logger = loggers[2]

    assert _version([]) == ""
    assert _version([version_one_logger]) == 1
    # No separator is passed here; the expected string is joined with "_".
    assert _version(loggers) == "0_2_1"
    assert _version(loggers, "-") == "0-2-1"