lightning/pytorch_lightning/utilities/logger.py

165 lines
5.6 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for loggers."""
from argparse import Namespace
from typing import Any, Dict, Generator, List, MutableMapping, Optional, Union
import numpy as np
import torch
def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[str, Any]:
    """Ensure parameters are a dict or convert to dict if necessary.

    Args:
        params: Target to be converted to a dictionary. ``None`` is accepted and treated as "no parameters".

    Returns:
        params as a dictionary:

        - a ``Namespace`` is converted with :func:`vars`, so the result is the namespace's own attribute dict
          (not a copy);
        - a dict is returned unchanged (same object);
        - ``None`` becomes an empty dict.
    """
    # in case converting from namespace
    if isinstance(params, Namespace):
        params = vars(params)
    # callers may pass ``None`` when no hyperparameters were provided
    if params is None:
        params = {}
    return params
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.

    Every callable value is invoked once with no arguments:

    - if the call returns a non-callable, that result replaces the value;
    - if the call returns another callable, the original callable's ``__name__`` is used instead;
    - if anything in that process raises, the callable's ``__name__`` (or ``None`` when it has none) is used.

    Non-callable values are kept as they are.

    Args:
        params: Dictionary containing the hyperparameters

    Returns:
        dictionary with all callables sanitized
    """

    def _resolve(candidate: Any) -> Any:
        if not callable(candidate):
            return candidate
        # Only one invocation is attempted; a callable result is never called again.
        try:
            produced = candidate()
            if not callable(produced):
                return produced
            return candidate.__name__
        # todo: specify the possible exception
        except Exception:
            return getattr(candidate, "__name__", None)

    sanitized: Dict[str, Any] = {}
    for name, value in params.items():
        sanitized[name] = _resolve(value)
    return sanitized
def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]:
    """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.

    Nested ``MutableMapping`` values and ``argparse.Namespace`` objects are flattened recursively. Keys are converted
    with ``str`` and joined with ``delimiter``. Leaf values of ``None`` are replaced by the string ``'None'``. A nested
    mapping that is empty contributes no entries. When two paths produce the same flattened key, the one visited last
    wins.

    Args:
        params: Dictionary containing the hyperparameters. A top-level ``Namespace`` is flattened like a dict.
        delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.

    Returns:
        Flattened dict.

    Examples:
        >>> _flatten_dict({'a': {'b': 'c'}})
        {'a/b': 'c'}
        >>> _flatten_dict({'a': {'b': 123}})
        {'a/b': 123}
        >>> _flatten_dict({5: {'a': 123}})
        {'5/a': 123}
        >>> _flatten_dict({'a': Namespace(b=None)})
        {'a/b': 'None'}
        >>> _flatten_dict(Namespace(a={'b': 1}))
        {'a/b': 1}
    """

    def _dict_generator(input_dict: Any, prefixes: Optional[List[str]] = None) -> Generator[List[Any], None, None]:
        # Each yielded list is ``[*key_parts, value]``. ``prefixes + [...]`` always builds a new list, so the
        # caller's ``prefixes`` is never mutated and no defensive copy is needed.
        prefixes = prefixes or []
        if isinstance(input_dict, Namespace):
            input_dict = vars(input_dict)
        if isinstance(input_dict, MutableMapping):
            for key, value in input_dict.items():
                key = str(key)
                if isinstance(value, (MutableMapping, Namespace)):
                    yield from _dict_generator(value, prefixes + [key])
                else:
                    yield prefixes + [key, value if value is not None else str(None)]
        else:
            # Only reachable for a non-mapping top-level input (recursion only descends into mappings/namespaces);
            # it is emitted as a single leaf under the empty key.
            yield prefixes + [input_dict if input_dict is None else str(input_dict)]

    return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Returns params with non-primitives converted to strings for logging.

    The dictionary is modified in place and also returned. NumPy ``bool_``, integer and floating scalars are
    converted to the equivalent Python scalars via ``.item()``. Any other value whose type is not exactly ``bool``,
    ``int``, ``float``, ``str`` or ``torch.Tensor`` is replaced by ``str(value)``. Subclasses of those types (for
    example ``torch.nn.Parameter``) are stringified too, because the check compares exact types.

    Args:
        params: Flat dictionary of hyperparameters. It is mutated in place.

    Returns:
        The same ``params`` object with its values sanitized.

    >>> params = {"float": 0.3,
    ...           "int": 1,
    ...           "string": "abc",
    ...           "bool": True,
    ...           "list": [1, 2, 3],
    ...           "namespace": Namespace(foo=3),
    ...           "layer": torch.nn.BatchNorm1d}
    >>> import pprint
    >>> pprint.pprint(_sanitize_params(params))  # doctest: +NORMALIZE_WHITESPACE
    {'bool': True,
     'float': 0.3,
     'int': 1,
     'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
     'list': '[1, 2, 3]',
     'namespace': 'Namespace(foo=3)',
     'string': 'abc'}
    """
    primitive_types = (bool, int, float, str, torch.Tensor)
    # Reassigning values of existing keys while iterating ``items()`` is safe: the key set does not change.
    for key, value in params.items():
        if isinstance(value, (np.bool_, np.integer, np.floating)):
            # convert relevant np scalars to python types first (instead of str)
            params[key] = value.item()
        elif type(value) not in primitive_types:
            params[key] = str(value)
    return params
def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
    """Prepend ``prefix`` and ``separator`` to every metric name.

    When ``prefix`` is falsy (e.g. ``""``), the very same ``metrics`` object is returned untouched. Otherwise a new
    dict is built whose keys are formatted as ``f"{prefix}{separator}{key}"`` and whose values are unchanged.

    Args:
        metrics: Dictionary with metric names as keys and measured quantities as values
        prefix: Prefix to insert before each key
        separator: Separates prefix and original key name

    Returns:
        Dictionary with prefix and separator inserted before each key
    """
    if not prefix:
        return metrics
    prefixed: Dict[str, float] = {}
    for name, value in metrics.items():
        prefixed[f"{prefix}{separator}{name}"] = value
    return prefixed
def _name(loggers: List[Any], separator: str = "_") -> str:
    """Return a combined name for ``loggers``.

    With exactly one logger, its ``name`` attribute is returned as is (no string conversion). Otherwise the
    ``str`` of each logger's ``name`` is collected, duplicates are dropped while keeping first-seen order, and the
    result is joined with ``separator``. An empty list therefore yields ``""``.
    """
    if len(loggers) != 1:
        names = [str(logger.name) for logger in loggers]
        # dict.fromkeys de-duplicates while preserving insertion order
        return separator.join(dict.fromkeys(names))
    return loggers[0].name
def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]:
    """Return a combined version for ``loggers``.

    With exactly one logger, its ``version`` attribute is returned as is (no string conversion). Otherwise the
    ``str`` of each logger's ``version`` is taken, only the first occurrence of each distinct string is kept (order
    preserved), and the pieces are joined with ``separator``. An empty list therefore yields ``""``.
    """
    if len(loggers) == 1:
        return loggers[0].version
    seen = set()
    ordered: List[str] = []
    for logger in loggers:
        text = str(logger.version)
        if text not in seen:
            seen.add(text)
            ordered.append(text)
    return separator.join(ordered)