Decouple utilities from `LightningLoggerBase` (#11484)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Akash Kwatra 2022-02-02 14:29:01 -08:00 committed by GitHub
parent fbc1f9f1d9
commit 115a5d08e8
12 changed files with 356 additions and 164 deletions
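In short: the parameter and metric helpers that previously lived on `LightningLoggerBase` (`_convert_params`, `_sanitize_callable_params`, `_flatten_dict`, `_sanitize_params`, `_add_prefix`) move to module-level functions in `pytorch_lightning.utilities.logger`, and `_add_prefix` now takes the prefix and separator explicitly. A minimal before/after sketch of a call site inside a logger's `log_hyperparams`/`log_metrics` (illustrative only; the concrete changes are in the per-file diffs below):

# before this commit: helpers were methods on the logger base class
params = self._convert_params(params)
params = self._flatten_dict(params)
metrics = self._add_prefix(metrics)

# after this commit: module-level helpers, with prefix and separator passed explicitly
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
params = _convert_params(params)
params = _flatten_dict(params)
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)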

View File

@@ -19,11 +19,10 @@ import operator
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from weakref import ReferenceType
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
@@ -178,107 +177,6 @@ class LightningLoggerBase(ABC):
"""
pass
@staticmethod
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
# in case converting from namespace
if isinstance(params, Namespace):
params = vars(params)
if params is None:
params = {}
return params
@staticmethod
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.
Args:
params: Dictionary containing the hyperparameters
Returns:
dictionary with all callables sanitized
"""
def _sanitize_callable(val):
# Give the callable one chance to return a value. Don't go down the rabbit hole of recursive calls.
if isinstance(val, Callable):
try:
_val = val()
if isinstance(_val, Callable):
return val.__name__
return _val
# todo: specify the possible exception
except Exception:
return getattr(val, "__name__", None)
return val
return {key: _sanitize_callable(val) for key, val in params.items()}
@staticmethod
def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]:
"""Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
Args:
params: Dictionary containing the hyperparameters
delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.
Returns:
Flattened dict.
Examples:
>>> LightningLoggerBase._flatten_dict({'a': {'b': 'c'}})
{'a/b': 'c'}
>>> LightningLoggerBase._flatten_dict({'a': {'b': 123}})
{'a/b': 123}
>>> LightningLoggerBase._flatten_dict({5: {'a': 123}})
{'5/a': 123}
"""
def _dict_generator(input_dict, prefixes=None):
prefixes = prefixes[:] if prefixes else []
if isinstance(input_dict, MutableMapping):
for key, value in input_dict.items():
key = str(key)
if isinstance(value, (MutableMapping, Namespace)):
value = vars(value) if isinstance(value, Namespace) else value
yield from _dict_generator(value, prefixes + [key])
else:
yield prefixes + [key, value if value is not None else str(None)]
else:
yield prefixes + [input_dict if input_dict is None else str(input_dict)]
return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}
@staticmethod
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""Returns params with non-primitvies converted to strings for logging.
>>> params = {"float": 0.3,
... "int": 1,
... "string": "abc",
... "bool": True,
... "list": [1, 2, 3],
... "namespace": Namespace(foo=3),
... "layer": torch.nn.BatchNorm1d}
>>> import pprint
>>> pprint.pprint(LightningLoggerBase._sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE
{'bool': True,
'float': 0.3,
'int': 1,
'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
'list': '[1, 2, 3]',
'namespace': 'Namespace(foo=3)',
'string': 'abc'}
"""
for k in params.keys():
# convert relevant np scalars to python types first (instead of str)
if isinstance(params[k], (np.bool_, np.integer, np.floating)):
params[k] = params[k].item()
elif type(params[k]) not in [bool, int, float, str, torch.Tensor]:
params[k] = str(params[k])
return params
@abstractmethod
def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
"""Record hyperparameters.
@@ -360,12 +258,6 @@ class LightningLoggerBase(ABC):
def version(self) -> Union[int, str]:
"""Return the experiment version."""
def _add_prefix(self, metrics: Dict[str, float]):
if self._prefix:
metrics = {f"{self._prefix}{self.LOGGER_JOIN_CHAR}{k}": v for k, v in metrics.items()}
return metrics
class LoggerCollection(LightningLoggerBase):
"""The :class:`LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`.

View File

@@ -28,6 +28,7 @@ import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
log = logging.getLogger(__name__)
_COMET_AVAILABLE = _module_available("comet_ml")
@@ -232,8 +233,8 @@ class CometLogger(LightningLoggerBase):
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
params = _convert_params(params)
params = _flatten_dict(params)
self.experiment.log_parameters(params)
@rank_zero_only
@@ -246,7 +247,7 @@ class CometLogger(LightningLoggerBase):
metrics_without_epoch[key] = val.cpu().detach()
epoch = metrics_without_epoch.pop("epoch", None)
metrics_without_epoch = self._add_prefix(metrics_without_epoch)
metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR)
self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch)
def reset_experiment(self):

View File

@@ -30,6 +30,7 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params
log = logging.getLogger(__name__)
@@ -193,12 +194,12 @@ class CSVLogger(LightningLoggerBase):
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = _convert_params(params)
self.experiment.log_hparams(params)
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
metrics = self._add_prefix(metrics)
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
self.experiment.log_metrics(metrics, step)
if step is not None and (step + 1) % self._flush_logs_every_n_steps == 0:
self.save()

View File

@@ -24,6 +24,7 @@ from typing import Any, Dict, Optional, Union
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
@@ -191,8 +192,8 @@ class MLFlowLogger(LightningLoggerBase):
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
params = _convert_params(params)
params = _flatten_dict(params)
for k, v in params.items():
if len(str(v)) > 250:
rank_zero_warn(
@@ -206,7 +207,7 @@ class MLFlowLogger(LightningLoggerBase):
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
metrics = self._add_prefix(metrics)
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
timestamp_ms = int(time() * 1000)
for k, v in metrics.items():

View File

@@ -34,6 +34,7 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
from pytorch_lightning.utilities.model_summary import ModelSummary
if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9:
@@ -479,8 +480,8 @@ class NeptuneLogger(LightningLoggerBase):
neptune_logger.log_hyperparams(PARAMS)
"""
params = self._convert_params(params)
params = self._sanitize_callable_params(params)
params = _convert_params(params)
params = _sanitize_callable_params(params)
parameters_key = self.PARAMETERS_KEY
parameters_key = self._construct_path_with_prefix(parameters_key)
@@ -498,7 +499,7 @@
if rank_zero_only.rank != 0:
raise ValueError("run tried to log from global_rank != 0")
metrics = self._add_prefix(metrics)
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
for key, val in metrics.items():
# `step` is ignored because Neptune expects strictly increasing step values which

View File

@@ -31,6 +31,8 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.logger import _sanitize_params as _utils_sanitize_params
log = logging.getLogger(__name__)
@@ -186,7 +188,7 @@ class TensorBoardLogger(LightningLoggerBase):
metrics: Dictionary with metric names as keys and measured quantities as values
"""
params = self._convert_params(params)
params = _convert_params(params)
# store params to output
if _OMEGACONF_AVAILABLE and isinstance(params, Container):
@@ -195,7 +197,7 @@ class TensorBoardLogger(LightningLoggerBase):
self.hparams.update(params)
# format params into a form suitable for TensorBoard
params = self._flatten_dict(params)
params = _flatten_dict(params)
params = self._sanitize_params(params)
if metrics is None:
@@ -216,7 +218,7 @@
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
metrics = self._add_prefix(metrics)
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
for k, v in metrics.items():
if isinstance(v, torch.Tensor):
@@ -311,7 +313,7 @@
@staticmethod
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
params = LightningLoggerBase._sanitize_params(params)
params = _utils_sanitize_params(params)
# logging of arrays with dimension > 1 is not supported, sanitize as string
return {k: str(v) if isinstance(v, (torch.Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()}
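For context, `TensorBoardLogger` keeps a thin `_sanitize_params` wrapper around the shared utility so that tensors and arrays with more than one dimension are logged as strings. A small sketch of the resulting behaviour, assuming the code as changed above:

import torch
from pytorch_lightning.loggers import TensorBoardLogger

# the shared utility keeps tensors as-is; the wrapper additionally stringifies anything with ndim > 1
sanitized = TensorBoardLogger._sanitize_params({"lr": 0.1, "weight": torch.ones(2, 2)})
print(sanitized)  # expected: {'lr': 0.1, 'weight': 'tensor([[1., 1.],\n        [1., 1.]])'}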

View File

@@ -22,6 +22,7 @@ import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
_TESTTUBE_AVAILABLE = _module_available("test_tube")
@@ -152,14 +153,14 @@ class TestTubeLogger(LightningLoggerBase):
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
# TODO: HACK figure out where this is being set to true
self.experiment.debug = self.debug
params = self._convert_params(params)
params = self._flatten_dict(params)
params = _convert_params(params)
params = _flatten_dict(params)
self.experiment.argparse(Namespace(**params))
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
# TODO: HACK figure out where this is being set to true
metrics = self._add_prefix(metrics)
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
self.experiment.debug = self.debug
self.experiment.log(metrics, global_step=step)

View File

@@ -29,6 +29,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experi
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
from pytorch_lightning.utilities.warnings import rank_zero_warn
_WANDB_AVAILABLE = _module_available("wandb")
@@ -356,16 +357,16 @@ class WandbLogger(LightningLoggerBase):
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
params = self._sanitize_callable_params(params)
params = _convert_params(params)
params = _flatten_dict(params)
params = _sanitize_callable_params(params)
self.experiment.config.update(params, allow_val_change=True)
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
metrics = self._add_prefix(metrics)
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
if step is not None:
self.experiment.log({**metrics, "trainer/global_step": step})
else:

View File

@@ -0,0 +1,148 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for loggers."""
from argparse import Namespace
from typing import Any, Dict, Generator, List, MutableMapping, Optional, Union
import numpy as np
import torch
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
"""Ensure parameters are a dict or convert to dict if neccesary.
Args:
params: Target to be converted to a dictionary
Returns:
params as a dictionary
"""
# in case converting from namespace
if isinstance(params, Namespace):
params = vars(params)
if params is None:
params = {}
return params
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.
Args:
params: Dictionary containing the hyperparameters
Returns:
dictionary with all callables sanitized
"""
def _sanitize_callable(val: Any) -> Any:
# Give the callable one chance to return a value. Don't go down the rabbit hole of recursive calls.
if callable(val):
try:
_val = val()
if callable(_val):
return val.__name__
return _val
# todo: specify the possible exception
except Exception:
return getattr(val, "__name__", None)
return val
return {key: _sanitize_callable(val) for key, val in params.items()}
def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]:
"""Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
Args:
params: Dictionary containing the hyperparameters
delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.
Returns:
Flattened dict.
Examples:
>>> _flatten_dict({'a': {'b': 'c'}})
{'a/b': 'c'}
>>> _flatten_dict({'a': {'b': 123}})
{'a/b': 123}
>>> _flatten_dict({5: {'a': 123}})
{'5/a': 123}
"""
def _dict_generator(
input_dict: Any, prefixes: List[Optional[str]] = None
) -> Generator[Any, Optional[List[str]], List[Any]]:
prefixes = prefixes[:] if prefixes else []
if isinstance(input_dict, MutableMapping):
for key, value in input_dict.items():
key = str(key)
if isinstance(value, (MutableMapping, Namespace)):
value = vars(value) if isinstance(value, Namespace) else value
yield from _dict_generator(value, prefixes + [key])
else:
yield prefixes + [key, value if value is not None else str(None)]
else:
yield prefixes + [input_dict if input_dict is None else str(input_dict)]
return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""Returns params with non-primitvies converted to strings for logging.
>>> params = {"float": 0.3,
... "int": 1,
... "string": "abc",
... "bool": True,
... "list": [1, 2, 3],
... "namespace": Namespace(foo=3),
... "layer": torch.nn.BatchNorm1d}
>>> import pprint
>>> pprint.pprint(_sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE
{'bool': True,
'float': 0.3,
'int': 1,
'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
'list': '[1, 2, 3]',
'namespace': 'Namespace(foo=3)',
'string': 'abc'}
"""
for k in params.keys():
# convert relevant np scalars to python types first (instead of str)
if isinstance(params[k], (np.bool_, np.integer, np.floating)):
params[k] = params[k].item()
elif type(params[k]) not in [bool, int, float, str, torch.Tensor]:
params[k] = str(params[k])
return params
def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
"""Insert prefix before each key in a dict, separated by the separator.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
prefix: Prefix to insert before each key
separator: Separates prefix and original key name
Returns:
Dictionary with prefix and separator inserted before each key
"""
if prefix:
metrics = {f"{prefix}{separator}{k}": v for k, v in metrics.items()}
return metrics
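As an end-to-end illustration (a standalone sketch, not part of this commit), the helpers chain the same way the built-in loggers use them in the diffs above:

from argparse import Namespace

from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_params

# hyperparameters may arrive as a Namespace or as a (possibly nested) dict
hparams = Namespace(optim={"name": "adam", "lr": 1e-3})
hparams = _convert_params(hparams)   # Namespace -> dict
hparams = _flatten_dict(hparams)     # {'optim': {...}} -> {'optim/name': 'adam', 'optim/lr': 0.001}
hparams = _sanitize_params(hparams)  # non-primitive values -> str
print(hparams)  # {'optim/name': 'adam', 'optim/lr': 0.001}

# metric keys get the user-supplied prefix joined by the logger's join character
print(_add_prefix({"loss": 0.25}, "train", "/"))  # {'train/loss': 0.25}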

View File

@@ -26,6 +26,7 @@ from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, Ten
from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _convert_params, _sanitize_params
from tests.helpers.boring_model import BoringDataModule, BoringModel
@@ -290,8 +291,8 @@ def test_np_sanitization():
@rank_zero_only
def log_hyperparams(self, params):
params = self._convert_params(params)
params = self._sanitize_params(params)
params = _convert_params(params)
params = _sanitize_params(params)
self.logged_params = params
logger = CustomParamsLogger()

View File

@@ -13,7 +13,6 @@
# limitations under the License.
import os
import pickle
from argparse import ArgumentParser
from unittest import mock
import pytest
@@ -279,36 +278,6 @@ def test_wandb_log_media(wandb, tmpdir):
wandb.init().log.assert_called_once_with({"samples": wandb.Table(), "trainer/global_step": 5})
def test_wandb_sanitize_callable_params(tmpdir):
"""Callback function are not serializiable.
Therefore, we get them a chance to return something and if the returned type is not accepted, return None.
"""
opt = "--max_epochs 1".split(" ")
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parent_parser=parser)
params = parser.parse_args(opt)
def return_something():
return "something"
params.something = return_something
def wrapper_something():
return return_something
params.wrapper_something_wo_name = lambda: lambda: "1"
params.wrapper_something = wrapper_something
params = WandbLogger._convert_params(params)
params = WandbLogger._flatten_dict(params)
params = WandbLogger._sanitize_callable_params(params)
assert params["gpus"] == "None"
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"
@mock.patch("pytorch_lightning.loggers.wandb.wandb")
def test_wandb_logger_offline_log_model(wandb, tmpdir):
"""Test that log_model=True raises an error in offline mode."""

View File

@@ -0,0 +1,174 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser, Namespace
import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.logger import (
_add_prefix,
_convert_params,
_flatten_dict,
_sanitize_callable_params,
_sanitize_params,
)
def test_convert_params():
"""Test conversion of params to a dict."""
# Test normal dict, make sure it is unchanged
params = {"foo": "bar", 1: 23}
assert type(params) == dict
params = _convert_params(params)
assert type(params) == dict
assert params["foo"] == "bar"
assert params[1] == 23
# Test None conversion
params = None
assert type(params) != dict
params = _convert_params(params)
assert type(params) == dict
assert params == {}
# Test conversion of argparse Namespace
opt = "--max_epochs 1".split(" ")
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parent_parser=parser)
params = parser.parse_args(opt)
assert type(params) == Namespace
params = _convert_params(params)
assert type(params) == dict
assert params["gpus"] is None
def test_flatten_dict():
"""Validate flatten_dict can handle nested dictionaries and argparse Namespace."""
# Test basic dict flattening with custom delimiter
params = {"a": {"b": "c"}}
params = _flatten_dict(params, "--")
assert "a" not in params
assert params["a--b"] == "c"
# Test complex nested dict flattening
params = {"a": {5: {"foo": "bar"}}, "b": 6, "c": {7: [1, 2, 3, 4], 8: "foo", 9: {10: "bar"}}}
params = _flatten_dict(params)
assert "a" not in params
assert params["a/5/foo"] == "bar"
assert params["b"] == 6
assert params["c/7"] == [1, 2, 3, 4]
assert params["c/8"] == "foo"
assert params["c/9/10"] == "bar"
# Test flattening of argparse Namespace
opt = "--max_epochs 1".split(" ")
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parent_parser=parser)
params = parser.parse_args(opt)
wrapping_dict = {"params": params}
params = _flatten_dict(wrapping_dict)
assert type(params) == dict
assert params["params/logger"] is True
assert params["params/gpus"] == "None"
assert "logger" not in params
assert "gpus" not in params
def test_sanitize_callable_params():
"""Callback function are not serializiable.
Therefore, we get them a chance to return something and if the returned type is not accepted, return None.
"""
opt = "--max_epochs 1".split(" ")
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parent_parser=parser)
params = parser.parse_args(opt)
def return_something():
return "something"
params.something = return_something
def wrapper_something():
return return_something
params.wrapper_something_wo_name = lambda: lambda: "1"
params.wrapper_something = wrapper_something
params = _convert_params(params)
params = _flatten_dict(params)
params = _sanitize_callable_params(params)
assert params["gpus"] == "None"
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"
def test_sanitize_params():
"""Verify sanitize params converts various types to loggable strings."""
params = {
"float": 0.3,
"int": 1,
"string": "abc",
"bool": True,
"list": [1, 2, 3],
"np_bool": np.bool_(False),
"np_int": np.int_(5),
"np_double": np.double(3.14159),
"namespace": Namespace(foo=3),
"layer": torch.nn.BatchNorm1d,
"tensor": torch.ones(3),
}
params = _sanitize_params(params)
assert params["float"] == 0.3
assert params["int"] == 1
assert params["string"] == "abc"
assert params["bool"] is True
assert params["list"] == "[1, 2, 3]"
assert params["np_bool"] is False
assert params["np_int"] == 5
assert params["np_double"] == 3.14159
assert params["namespace"] == "Namespace(foo=3)"
assert params["layer"] == "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>"
assert torch.equal(params["tensor"], torch.ones(3))
def test_add_prefix():
"""Verify add_prefix modifies the dict keys correctly."""
metrics = {"metric1": 1, "metric2": 2}
metrics = _add_prefix(metrics, "prefix", "-")
assert "prefix-metric1" in metrics
assert "prefix-metric2" in metrics
assert "metric1" not in metrics
assert "metric2" not in metrics
metrics = _add_prefix(metrics, "prefix2", "_")
assert "prefix2_prefix-metric1" in metrics
assert "prefix2_prefix-metric2" in metrics
assert "prefix-metric1" not in metrics
assert "prefix-metric2" not in metrics
assert metrics["prefix2_prefix-metric1"] == 1
assert metrics["prefix2_prefix-metric2"] == 2