diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 0cbbc134ca..a104aa83b1 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -19,11 +19,10 @@ import operator from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union from weakref import ReferenceType import numpy as np -import torch import pytorch_lightning as pl from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint @@ -178,107 +177,6 @@ class LightningLoggerBase(ABC): """ pass - @staticmethod - def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: - # in case converting from namespace - if isinstance(params, Namespace): - params = vars(params) - - if params is None: - params = {} - - return params - - @staticmethod - def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]: - """Sanitize callable params dict, e.g. ``{'a': } -> {'a': 'function_****'}``. - - Args: - params: Dictionary containing the hyperparameters - - Returns: - dictionary with all callables sanitized - """ - - def _sanitize_callable(val): - # Give them one chance to return a value. Don't go rabbit hole of recursive call - if isinstance(val, Callable): - try: - _val = val() - if isinstance(_val, Callable): - return val.__name__ - return _val - # todo: specify the possible exception - except Exception: - return getattr(val, "__name__", None) - return val - - return {key: _sanitize_callable(val) for key, val in params.items()} - - @staticmethod - def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]: - """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. - - Args: - params: Dictionary containing the hyperparameters - delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``. - - Returns: - Flattened dict. - - Examples: - >>> LightningLoggerBase._flatten_dict({'a': {'b': 'c'}}) - {'a/b': 'c'} - >>> LightningLoggerBase._flatten_dict({'a': {'b': 123}}) - {'a/b': 123} - >>> LightningLoggerBase._flatten_dict({5: {'a': 123}}) - {'5/a': 123} - """ - - def _dict_generator(input_dict, prefixes=None): - prefixes = prefixes[:] if prefixes else [] - if isinstance(input_dict, MutableMapping): - for key, value in input_dict.items(): - key = str(key) - if isinstance(value, (MutableMapping, Namespace)): - value = vars(value) if isinstance(value, Namespace) else value - yield from _dict_generator(value, prefixes + [key]) - else: - yield prefixes + [key, value if value is not None else str(None)] - else: - yield prefixes + [input_dict if input_dict is None else str(input_dict)] - - return {delimiter.join(keys): val for *keys, val in _dict_generator(params)} - - @staticmethod - def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: - """Returns params with non-primitvies converted to strings for logging. - - >>> params = {"float": 0.3, - ... "int": 1, - ... "string": "abc", - ... "bool": True, - ... "list": [1, 2, 3], - ... "namespace": Namespace(foo=3), - ... "layer": torch.nn.BatchNorm1d} - >>> import pprint - >>> pprint.pprint(LightningLoggerBase._sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE - {'bool': True, - 'float': 0.3, - 'int': 1, - 'layer': "", - 'list': '[1, 2, 3]', - 'namespace': 'Namespace(foo=3)', - 'string': 'abc'} - """ - for k in params.keys(): - # convert relevant np scalars to python types first (instead of str) - if isinstance(params[k], (np.bool_, np.integer, np.floating)): - params[k] = params[k].item() - elif type(params[k]) not in [bool, int, float, str, torch.Tensor]: - params[k] = str(params[k]) - return params - @abstractmethod def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): """Record hyperparameters. @@ -360,12 +258,6 @@ class LightningLoggerBase(ABC): def version(self) -> Union[int, str]: """Return the experiment version.""" - def _add_prefix(self, metrics: Dict[str, float]): - if self._prefix: - metrics = {f"{self._prefix}{self.LOGGER_JOIN_CHAR}{k}": v for k, v in metrics.items()} - - return metrics - class LoggerCollection(LightningLoggerBase): """The :class:`LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`. diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 69db7cdf73..7fe57b243b 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -28,6 +28,7 @@ import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict log = logging.getLogger(__name__) _COMET_AVAILABLE = _module_available("comet_ml") @@ -232,8 +233,8 @@ class CometLogger(LightningLoggerBase): @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: - params = self._convert_params(params) - params = self._flatten_dict(params) + params = _convert_params(params) + params = _flatten_dict(params) self.experiment.log_parameters(params) @rank_zero_only @@ -246,7 +247,7 @@ class CometLogger(LightningLoggerBase): metrics_without_epoch[key] = val.cpu().detach() epoch = metrics_without_epoch.pop("epoch", None) - metrics_without_epoch = self._add_prefix(metrics_without_epoch) + metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR) self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) def reset_experiment(self): diff --git a/pytorch_lightning/loggers/csv_logs.py b/pytorch_lightning/loggers/csv_logs.py index 454a17905c..c6e70b580c 100644 --- a/pytorch_lightning/loggers/csv_logs.py +++ b/pytorch_lightning/loggers/csv_logs.py @@ -30,6 +30,7 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.logger import _add_prefix, _convert_params log = logging.getLogger(__name__) @@ -193,12 +194,12 @@ class CSVLogger(LightningLoggerBase): @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: - params = self._convert_params(params) + params = _convert_params(params) self.experiment.log_hparams(params) @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: - metrics = self._add_prefix(metrics) + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) self.experiment.log_metrics(metrics, step) if step is not None and (step + 1) % self._flush_logs_every_n_steps == 0: self.save() diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index b349ced6ec..0548599329 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -24,6 +24,7 @@ from typing import Any, Dict, Optional, Union from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict log = logging.getLogger(__name__) LOCAL_FILE_URI_PREFIX = "file:" @@ -191,8 +192,8 @@ class MLFlowLogger(LightningLoggerBase): @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: - params = self._convert_params(params) - params = self._flatten_dict(params) + params = _convert_params(params) + params = _flatten_dict(params) for k, v in params.items(): if len(str(v)) > 250: rank_zero_warn( @@ -206,7 +207,7 @@ class MLFlowLogger(LightningLoggerBase): def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - metrics = self._add_prefix(metrics) + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) timestamp_ms = int(time() * 1000) for k, v in metrics.items(): diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 7895eb0c75..96fa45b18e 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -34,6 +34,7 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9 +from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params from pytorch_lightning.utilities.model_summary import ModelSummary if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9: @@ -479,8 +480,8 @@ class NeptuneLogger(LightningLoggerBase): neptune_logger.log_hyperparams(PARAMS) """ - params = self._convert_params(params) - params = self._sanitize_callable_params(params) + params = _convert_params(params) + params = _sanitize_callable_params(params) parameters_key = self.PARAMETERS_KEY parameters_key = self._construct_path_with_prefix(parameters_key) @@ -498,7 +499,7 @@ class NeptuneLogger(LightningLoggerBase): if rank_zero_only.rank != 0: raise ValueError("run tried to log from global_rank != 0") - metrics = self._add_prefix(metrics) + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) for key, val in metrics.items(): # `step` is ignored because Neptune expects strictly increasing step values which diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 6767b56bc4..05e94bbc0a 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -31,6 +31,8 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from pytorch_lightning.utilities.logger import _sanitize_params as _utils_sanitize_params log = logging.getLogger(__name__) @@ -186,7 +188,7 @@ class TensorBoardLogger(LightningLoggerBase): metrics: Dictionary with metric names as keys and measured quantities as values """ - params = self._convert_params(params) + params = _convert_params(params) # store params to output if _OMEGACONF_AVAILABLE and isinstance(params, Container): @@ -195,7 +197,7 @@ class TensorBoardLogger(LightningLoggerBase): self.hparams.update(params) # format params into the suitable for tensorboard - params = self._flatten_dict(params) + params = _flatten_dict(params) params = self._sanitize_params(params) if metrics is None: @@ -216,7 +218,7 @@ class TensorBoardLogger(LightningLoggerBase): def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - metrics = self._add_prefix(metrics) + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) for k, v in metrics.items(): if isinstance(v, torch.Tensor): @@ -311,7 +313,7 @@ class TensorBoardLogger(LightningLoggerBase): @staticmethod def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: - params = LightningLoggerBase._sanitize_params(params) + params = _utils_sanitize_params(params) # logging of arrays with dimension > 1 is not supported, sanitize as string return {k: str(v) if isinstance(v, (torch.Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()} diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index b0eb42069e..df158c8253 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -22,6 +22,7 @@ import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict _TESTTUBE_AVAILABLE = _module_available("test_tube") @@ -152,14 +153,14 @@ class TestTubeLogger(LightningLoggerBase): def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # TODO: HACK figure out where this is being set to true self.experiment.debug = self.debug - params = self._convert_params(params) - params = self._flatten_dict(params) + params = _convert_params(params) + params = _flatten_dict(params) self.experiment.argparse(Namespace(**params)) @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: # TODO: HACK figure out where this is being set to true - metrics = self._add_prefix(metrics) + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) self.experiment.debug = self.debug self.experiment.log(metrics, global_step=step) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 4c09d6591f..ee1b1b9e94 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -29,6 +29,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experi from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _compare_version +from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params from pytorch_lightning.utilities.warnings import rank_zero_warn _WANDB_AVAILABLE = _module_available("wandb") @@ -356,16 +357,16 @@ class WandbLogger(LightningLoggerBase): @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: - params = self._convert_params(params) - params = self._flatten_dict(params) - params = self._sanitize_callable_params(params) + params = _convert_params(params) + params = _flatten_dict(params) + params = _sanitize_callable_params(params) self.experiment.config.update(params, allow_val_change=True) @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - metrics = self._add_prefix(metrics) + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) if step is not None: self.experiment.log({**metrics, "trainer/global_step": step}) else: diff --git a/pytorch_lightning/utilities/logger.py b/pytorch_lightning/utilities/logger.py new file mode 100644 index 0000000000..62c6b78f93 --- /dev/null +++ b/pytorch_lightning/utilities/logger.py @@ -0,0 +1,148 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for loggers.""" + +from argparse import Namespace +from typing import Any, Dict, Generator, List, MutableMapping, Optional, Union + +import numpy as np +import torch + + +def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: + """Ensure parameters are a dict or convert to dict if neccesary. + Args: + params: Target to be converted to a dictionary + + Returns: + params as a dictionary + + """ + # in case converting from namespace + if isinstance(params, Namespace): + params = vars(params) + + if params is None: + params = {} + + return params + + +def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]: + """Sanitize callable params dict, e.g. ``{'a': } -> {'a': 'function_****'}``. + + Args: + params: Dictionary containing the hyperparameters + + Returns: + dictionary with all callables sanitized + """ + + def _sanitize_callable(val: Any) -> Any: + # Give them one chance to return a value. Don't go rabbit hole of recursive call + if callable(val): + try: + _val = val() + if callable(_val): + return val.__name__ + return _val + # todo: specify the possible exception + except Exception: + return getattr(val, "__name__", None) + return val + + return {key: _sanitize_callable(val) for key, val in params.items()} + + +def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]: + """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. + + Args: + params: Dictionary containing the hyperparameters + delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``. + + Returns: + Flattened dict. + + Examples: + >>> _flatten_dict({'a': {'b': 'c'}}) + {'a/b': 'c'} + >>> _flatten_dict({'a': {'b': 123}}) + {'a/b': 123} + >>> _flatten_dict({5: {'a': 123}}) + {'5/a': 123} + """ + + def _dict_generator( + input_dict: Any, prefixes: List[Optional[str]] = None + ) -> Generator[Any, Optional[List[str]], List[Any]]: + prefixes = prefixes[:] if prefixes else [] + if isinstance(input_dict, MutableMapping): + for key, value in input_dict.items(): + key = str(key) + if isinstance(value, (MutableMapping, Namespace)): + value = vars(value) if isinstance(value, Namespace) else value + yield from _dict_generator(value, prefixes + [key]) + else: + yield prefixes + [key, value if value is not None else str(None)] + else: + yield prefixes + [input_dict if input_dict is None else str(input_dict)] + + return {delimiter.join(keys): val for *keys, val in _dict_generator(params)} + + +def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: + """Returns params with non-primitvies converted to strings for logging. + + >>> params = {"float": 0.3, + ... "int": 1, + ... "string": "abc", + ... "bool": True, + ... "list": [1, 2, 3], + ... "namespace": Namespace(foo=3), + ... "layer": torch.nn.BatchNorm1d} + >>> import pprint + >>> pprint.pprint(_sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE + {'bool': True, + 'float': 0.3, + 'int': 1, + 'layer': "", + 'list': '[1, 2, 3]', + 'namespace': 'Namespace(foo=3)', + 'string': 'abc'} + """ + for k in params.keys(): + # convert relevant np scalars to python types first (instead of str) + if isinstance(params[k], (np.bool_, np.integer, np.floating)): + params[k] = params[k].item() + elif type(params[k]) not in [bool, int, float, str, torch.Tensor]: + params[k] = str(params[k]) + return params + + +def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: + """Insert prefix before each key in a dict, separated by the separator. + + Args: + metrics: Dictionary with metric names as keys and measured quantities as values + prefix: Prefix to insert before each key + separator: Separates prefix and original key name + + Returns: + Dictionary with prefix and separator inserted before each key + """ + if prefix: + metrics = {f"{prefix}{separator}{k}": v for k, v in metrics.items()} + + return metrics diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index cb5721d6a9..eb7dd949e5 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -26,6 +26,7 @@ from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, Ten from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.logger import _convert_params, _sanitize_params from tests.helpers.boring_model import BoringDataModule, BoringModel @@ -290,8 +291,8 @@ def test_np_sanitization(): @rank_zero_only def log_hyperparams(self, params): - params = self._convert_params(params) - params = self._sanitize_params(params) + params = _convert_params(params) + params = _sanitize_params(params) self.logged_params = params logger = CustomParamsLogger() diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 85b20c5624..705fe8f4b3 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -13,7 +13,6 @@ # limitations under the License. import os import pickle -from argparse import ArgumentParser from unittest import mock import pytest @@ -279,36 +278,6 @@ def test_wandb_log_media(wandb, tmpdir): wandb.init().log.assert_called_once_with({"samples": wandb.Table(), "trainer/global_step": 5}) -def test_wandb_sanitize_callable_params(tmpdir): - """Callback function are not serializiable. - - Therefore, we get them a chance to return something and if the returned type is not accepted, return None. - """ - opt = "--max_epochs 1".split(" ") - parser = ArgumentParser() - parser = Trainer.add_argparse_args(parent_parser=parser) - params = parser.parse_args(opt) - - def return_something(): - return "something" - - params.something = return_something - - def wrapper_something(): - return return_something - - params.wrapper_something_wo_name = lambda: lambda: "1" - params.wrapper_something = wrapper_something - - params = WandbLogger._convert_params(params) - params = WandbLogger._flatten_dict(params) - params = WandbLogger._sanitize_callable_params(params) - assert params["gpus"] == "None" - assert params["something"] == "something" - assert params["wrapper_something"] == "wrapper_something" - assert params["wrapper_something_wo_name"] == "" - - @mock.patch("pytorch_lightning.loggers.wandb.wandb") def test_wandb_logger_offline_log_model(wandb, tmpdir): """Test that log_model=True raises an error in offline mode.""" diff --git a/tests/utilities/test_logger.py b/tests/utilities/test_logger.py new file mode 100644 index 0000000000..8d9b495fb9 --- /dev/null +++ b/tests/utilities/test_logger.py @@ -0,0 +1,174 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from argparse import ArgumentParser, Namespace + +import numpy as np +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.logger import ( + _add_prefix, + _convert_params, + _flatten_dict, + _sanitize_callable_params, + _sanitize_params, +) + + +def test_convert_params(): + """Test conversion of params to a dict.""" + + # Test normal dict, make sure it is unchanged + params = {"foo": "bar", 1: 23} + assert type(params) == dict + params = _convert_params(params) + assert type(params) == dict + assert params["foo"] == "bar" + assert params[1] == 23 + + # Test None conversion + params = None + assert type(params) != dict + params = _convert_params(params) + assert type(params) == dict + assert params == {} + + # Test conversion of argparse Namespace + opt = "--max_epochs 1".split(" ") + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parent_parser=parser) + params = parser.parse_args(opt) + + assert type(params) == Namespace + params = _convert_params(params) + assert type(params) == dict + assert params["gpus"] is None + + +def test_flatten_dict(): + """Validate flatten_dict can handle nested dictionaries and argparse Namespace.""" + + # Test basic dict flattening with custom delimiter + params = {"a": {"b": "c"}} + params = _flatten_dict(params, "--") + + assert "a" not in params + assert params["a--b"] == "c" + + # Test complex nested dict flattening + params = {"a": {5: {"foo": "bar"}}, "b": 6, "c": {7: [1, 2, 3, 4], 8: "foo", 9: {10: "bar"}}} + params = _flatten_dict(params) + + assert "a" not in params + assert params["a/5/foo"] == "bar" + assert params["b"] == 6 + assert params["c/7"] == [1, 2, 3, 4] + assert params["c/8"] == "foo" + assert params["c/9/10"] == "bar" + + # Test flattening of argparse Namespace + opt = "--max_epochs 1".split(" ") + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parent_parser=parser) + params = parser.parse_args(opt) + wrapping_dict = {"params": params} + params = _flatten_dict(wrapping_dict) + + assert type(params) == dict + assert params["params/logger"] is True + assert params["params/gpus"] == "None" + assert "logger" not in params + assert "gpus" not in params + + +def test_sanitize_callable_params(): + """Callback function are not serializiable. + + Therefore, we get them a chance to return something and if the returned type is not accepted, return None. + """ + opt = "--max_epochs 1".split(" ") + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parent_parser=parser) + params = parser.parse_args(opt) + + def return_something(): + return "something" + + params.something = return_something + + def wrapper_something(): + return return_something + + params.wrapper_something_wo_name = lambda: lambda: "1" + params.wrapper_something = wrapper_something + + params = _convert_params(params) + params = _flatten_dict(params) + params = _sanitize_callable_params(params) + assert params["gpus"] == "None" + assert params["something"] == "something" + assert params["wrapper_something"] == "wrapper_something" + assert params["wrapper_something_wo_name"] == "" + + +def test_sanitize_params(): + """Verify sanitize params converts various types to loggable strings.""" + + params = { + "float": 0.3, + "int": 1, + "string": "abc", + "bool": True, + "list": [1, 2, 3], + "np_bool": np.bool_(False), + "np_int": np.int_(5), + "np_double": np.double(3.14159), + "namespace": Namespace(foo=3), + "layer": torch.nn.BatchNorm1d, + "tensor": torch.ones(3), + } + params = _sanitize_params(params) + + assert params["float"] == 0.3 + assert params["int"] == 1 + assert params["string"] == "abc" + assert params["bool"] is True + assert params["list"] == "[1, 2, 3]" + assert params["np_bool"] is False + assert params["np_int"] == 5 + assert params["np_double"] == 3.14159 + assert params["namespace"] == "Namespace(foo=3)" + assert params["layer"] == "" + assert torch.equal(params["tensor"], torch.ones(3)) + + +def test_add_prefix(): + """Verify add_prefix modifies the dict keys correctly.""" + + metrics = {"metric1": 1, "metric2": 2} + metrics = _add_prefix(metrics, "prefix", "-") + + assert "prefix-metric1" in metrics + assert "prefix-metric2" in metrics + assert "metric1" not in metrics + assert "metric2" not in metrics + + metrics = _add_prefix(metrics, "prefix2", "_") + + assert "prefix2_prefix-metric1" in metrics + assert "prefix2_prefix-metric2" in metrics + assert "prefix-metric1" not in metrics + assert "prefix-metric2" not in metrics + assert metrics["prefix2_prefix-metric1"] == 1 + assert metrics["prefix2_prefix-metric2"] == 2