Decouple utilities from `LightningLoggerBase` (#11484)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com> Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
parent
fbc1f9f1d9
commit
115a5d08e8
|
@ -19,11 +19,10 @@ import operator
|
|||
from abc import ABC, abstractmethod
|
||||
from argparse import Namespace
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
from weakref import ReferenceType
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
|
@ -178,107 +177,6 @@ class LightningLoggerBase(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
|
||||
# in case converting from namespace
|
||||
if isinstance(params, Namespace):
|
||||
params = vars(params)
|
||||
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.
|
||||
|
||||
Args:
|
||||
params: Dictionary containing the hyperparameters
|
||||
|
||||
Returns:
|
||||
dictionary with all callables sanitized
|
||||
"""
|
||||
|
||||
def _sanitize_callable(val):
|
||||
# Give them one chance to return a value. Don't go rabbit hole of recursive call
|
||||
if isinstance(val, Callable):
|
||||
try:
|
||||
_val = val()
|
||||
if isinstance(_val, Callable):
|
||||
return val.__name__
|
||||
return _val
|
||||
# todo: specify the possible exception
|
||||
except Exception:
|
||||
return getattr(val, "__name__", None)
|
||||
return val
|
||||
|
||||
return {key: _sanitize_callable(val) for key, val in params.items()}
|
||||
|
||||
@staticmethod
|
||||
def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]:
|
||||
"""Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
|
||||
|
||||
Args:
|
||||
params: Dictionary containing the hyperparameters
|
||||
delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.
|
||||
|
||||
Returns:
|
||||
Flattened dict.
|
||||
|
||||
Examples:
|
||||
>>> LightningLoggerBase._flatten_dict({'a': {'b': 'c'}})
|
||||
{'a/b': 'c'}
|
||||
>>> LightningLoggerBase._flatten_dict({'a': {'b': 123}})
|
||||
{'a/b': 123}
|
||||
>>> LightningLoggerBase._flatten_dict({5: {'a': 123}})
|
||||
{'5/a': 123}
|
||||
"""
|
||||
|
||||
def _dict_generator(input_dict, prefixes=None):
|
||||
prefixes = prefixes[:] if prefixes else []
|
||||
if isinstance(input_dict, MutableMapping):
|
||||
for key, value in input_dict.items():
|
||||
key = str(key)
|
||||
if isinstance(value, (MutableMapping, Namespace)):
|
||||
value = vars(value) if isinstance(value, Namespace) else value
|
||||
yield from _dict_generator(value, prefixes + [key])
|
||||
else:
|
||||
yield prefixes + [key, value if value is not None else str(None)]
|
||||
else:
|
||||
yield prefixes + [input_dict if input_dict is None else str(input_dict)]
|
||||
|
||||
return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Returns params with non-primitvies converted to strings for logging.
|
||||
|
||||
>>> params = {"float": 0.3,
|
||||
... "int": 1,
|
||||
... "string": "abc",
|
||||
... "bool": True,
|
||||
... "list": [1, 2, 3],
|
||||
... "namespace": Namespace(foo=3),
|
||||
... "layer": torch.nn.BatchNorm1d}
|
||||
>>> import pprint
|
||||
>>> pprint.pprint(LightningLoggerBase._sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE
|
||||
{'bool': True,
|
||||
'float': 0.3,
|
||||
'int': 1,
|
||||
'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
|
||||
'list': '[1, 2, 3]',
|
||||
'namespace': 'Namespace(foo=3)',
|
||||
'string': 'abc'}
|
||||
"""
|
||||
for k in params.keys():
|
||||
# convert relevant np scalars to python types first (instead of str)
|
||||
if isinstance(params[k], (np.bool_, np.integer, np.floating)):
|
||||
params[k] = params[k].item()
|
||||
elif type(params[k]) not in [bool, int, float, str, torch.Tensor]:
|
||||
params[k] = str(params[k])
|
||||
return params
|
||||
|
||||
@abstractmethod
|
||||
def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
|
||||
"""Record hyperparameters.
|
||||
|
@ -360,12 +258,6 @@ class LightningLoggerBase(ABC):
|
|||
def version(self) -> Union[int, str]:
|
||||
"""Return the experiment version."""
|
||||
|
||||
def _add_prefix(self, metrics: Dict[str, float]):
|
||||
if self._prefix:
|
||||
metrics = {f"{self._prefix}{self.LOGGER_JOIN_CHAR}{k}": v for k, v in metrics.items()}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
class LoggerCollection(LightningLoggerBase):
|
||||
"""The :class:`LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`.
|
||||
|
|
|
@ -28,6 +28,7 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _module_available, rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_COMET_AVAILABLE = _module_available("comet_ml")
|
||||
|
@ -232,8 +233,8 @@ class CometLogger(LightningLoggerBase):
|
|||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
|
||||
params = self._convert_params(params)
|
||||
params = self._flatten_dict(params)
|
||||
params = _convert_params(params)
|
||||
params = _flatten_dict(params)
|
||||
self.experiment.log_parameters(params)
|
||||
|
||||
@rank_zero_only
|
||||
|
@ -246,7 +247,7 @@ class CometLogger(LightningLoggerBase):
|
|||
metrics_without_epoch[key] = val.cpu().detach()
|
||||
|
||||
epoch = metrics_without_epoch.pop("epoch", None)
|
||||
metrics_without_epoch = self._add_prefix(metrics_without_epoch)
|
||||
metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR)
|
||||
self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch)
|
||||
|
||||
def reset_experiment(self):
|
||||
|
|
|
@ -30,6 +30,7 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
|
|||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -193,12 +194,12 @@ class CSVLogger(LightningLoggerBase):
|
|||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
|
||||
params = self._convert_params(params)
|
||||
params = _convert_params(params)
|
||||
self.experiment.log_hparams(params)
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
||||
metrics = self._add_prefix(metrics)
|
||||
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
|
||||
self.experiment.log_metrics(metrics, step)
|
||||
if step is not None and (step + 1) % self._flush_logs_every_n_steps == 0:
|
||||
self.save()
|
||||
|
|
|
@ -24,6 +24,7 @@ from typing import Any, Dict, Optional, Union
|
|||
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
LOCAL_FILE_URI_PREFIX = "file:"
|
||||
|
@ -191,8 +192,8 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
|
||||
params = self._convert_params(params)
|
||||
params = self._flatten_dict(params)
|
||||
params = _convert_params(params)
|
||||
params = _flatten_dict(params)
|
||||
for k, v in params.items():
|
||||
if len(str(v)) > 250:
|
||||
rank_zero_warn(
|
||||
|
@ -206,7 +207,7 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
||||
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
|
||||
|
||||
metrics = self._add_prefix(metrics)
|
||||
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
|
||||
|
||||
timestamp_ms = int(time() * 1000)
|
||||
for k, v in metrics.items():
|
||||
|
|
|
@ -34,6 +34,7 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
|||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
|
||||
from pytorch_lightning.utilities.model_summary import ModelSummary
|
||||
|
||||
if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9:
|
||||
|
@ -479,8 +480,8 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
|
||||
neptune_logger.log_hyperparams(PARAMS)
|
||||
"""
|
||||
params = self._convert_params(params)
|
||||
params = self._sanitize_callable_params(params)
|
||||
params = _convert_params(params)
|
||||
params = _sanitize_callable_params(params)
|
||||
|
||||
parameters_key = self.PARAMETERS_KEY
|
||||
parameters_key = self._construct_path_with_prefix(parameters_key)
|
||||
|
@ -498,7 +499,7 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
if rank_zero_only.rank != 0:
|
||||
raise ValueError("run tried to log from global_rank != 0")
|
||||
|
||||
metrics = self._add_prefix(metrics)
|
||||
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
|
||||
|
||||
for key, val in metrics.items():
|
||||
# `step` is ignored because Neptune expects strictly increasing step values which
|
||||
|
|
|
@ -31,6 +31,8 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
|
|||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
from pytorch_lightning.utilities.logger import _sanitize_params as _utils_sanitize_params
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -186,7 +188,7 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
metrics: Dictionary with metric names as keys and measured quantities as values
|
||||
"""
|
||||
|
||||
params = self._convert_params(params)
|
||||
params = _convert_params(params)
|
||||
|
||||
# store params to output
|
||||
if _OMEGACONF_AVAILABLE and isinstance(params, Container):
|
||||
|
@ -195,7 +197,7 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
self.hparams.update(params)
|
||||
|
||||
# format params into the suitable for tensorboard
|
||||
params = self._flatten_dict(params)
|
||||
params = _flatten_dict(params)
|
||||
params = self._sanitize_params(params)
|
||||
|
||||
if metrics is None:
|
||||
|
@ -216,7 +218,7 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
||||
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
|
||||
|
||||
metrics = self._add_prefix(metrics)
|
||||
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
|
||||
|
||||
for k, v in metrics.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
|
@ -311,7 +313,7 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
|
||||
@staticmethod
|
||||
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
params = LightningLoggerBase._sanitize_params(params)
|
||||
params = _utils_sanitize_params(params)
|
||||
# logging of arrays with dimension > 1 is not supported, sanitize as string
|
||||
return {k: str(v) if isinstance(v, (torch.Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
|
||||
_TESTTUBE_AVAILABLE = _module_available("test_tube")
|
||||
|
||||
|
@ -152,14 +153,14 @@ class TestTubeLogger(LightningLoggerBase):
|
|||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
|
||||
# TODO: HACK figure out where this is being set to true
|
||||
self.experiment.debug = self.debug
|
||||
params = self._convert_params(params)
|
||||
params = self._flatten_dict(params)
|
||||
params = _convert_params(params)
|
||||
params = _flatten_dict(params)
|
||||
self.experiment.argparse(Namespace(**params))
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
||||
# TODO: HACK figure out where this is being set to true
|
||||
metrics = self._add_prefix(metrics)
|
||||
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
|
||||
self.experiment.debug = self.debug
|
||||
self.experiment.log(metrics, global_step=step)
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experi
|
|||
from pytorch_lightning.utilities import _module_available, rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _compare_version
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_warn
|
||||
|
||||
_WANDB_AVAILABLE = _module_available("wandb")
|
||||
|
@ -356,16 +357,16 @@ class WandbLogger(LightningLoggerBase):
|
|||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
|
||||
params = self._convert_params(params)
|
||||
params = self._flatten_dict(params)
|
||||
params = self._sanitize_callable_params(params)
|
||||
params = _convert_params(params)
|
||||
params = _flatten_dict(params)
|
||||
params = _sanitize_callable_params(params)
|
||||
self.experiment.config.update(params, allow_val_change=True)
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
||||
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
|
||||
|
||||
metrics = self._add_prefix(metrics)
|
||||
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
|
||||
if step is not None:
|
||||
self.experiment.log({**metrics, "trainer/global_step": step})
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for loggers."""
|
||||
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, Generator, List, MutableMapping, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
|
||||
"""Ensure parameters are a dict or convert to dict if neccesary.
|
||||
Args:
|
||||
params: Target to be converted to a dictionary
|
||||
|
||||
Returns:
|
||||
params as a dictionary
|
||||
|
||||
"""
|
||||
# in case converting from namespace
|
||||
if isinstance(params, Namespace):
|
||||
params = vars(params)
|
||||
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.
|
||||
|
||||
Args:
|
||||
params: Dictionary containing the hyperparameters
|
||||
|
||||
Returns:
|
||||
dictionary with all callables sanitized
|
||||
"""
|
||||
|
||||
def _sanitize_callable(val: Any) -> Any:
|
||||
# Give them one chance to return a value. Don't go rabbit hole of recursive call
|
||||
if callable(val):
|
||||
try:
|
||||
_val = val()
|
||||
if callable(_val):
|
||||
return val.__name__
|
||||
return _val
|
||||
# todo: specify the possible exception
|
||||
except Exception:
|
||||
return getattr(val, "__name__", None)
|
||||
return val
|
||||
|
||||
return {key: _sanitize_callable(val) for key, val in params.items()}
|
||||
|
||||
|
||||
def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]:
|
||||
"""Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
|
||||
|
||||
Args:
|
||||
params: Dictionary containing the hyperparameters
|
||||
delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.
|
||||
|
||||
Returns:
|
||||
Flattened dict.
|
||||
|
||||
Examples:
|
||||
>>> _flatten_dict({'a': {'b': 'c'}})
|
||||
{'a/b': 'c'}
|
||||
>>> _flatten_dict({'a': {'b': 123}})
|
||||
{'a/b': 123}
|
||||
>>> _flatten_dict({5: {'a': 123}})
|
||||
{'5/a': 123}
|
||||
"""
|
||||
|
||||
def _dict_generator(
|
||||
input_dict: Any, prefixes: List[Optional[str]] = None
|
||||
) -> Generator[Any, Optional[List[str]], List[Any]]:
|
||||
prefixes = prefixes[:] if prefixes else []
|
||||
if isinstance(input_dict, MutableMapping):
|
||||
for key, value in input_dict.items():
|
||||
key = str(key)
|
||||
if isinstance(value, (MutableMapping, Namespace)):
|
||||
value = vars(value) if isinstance(value, Namespace) else value
|
||||
yield from _dict_generator(value, prefixes + [key])
|
||||
else:
|
||||
yield prefixes + [key, value if value is not None else str(None)]
|
||||
else:
|
||||
yield prefixes + [input_dict if input_dict is None else str(input_dict)]
|
||||
|
||||
return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}
|
||||
|
||||
|
||||
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Returns params with non-primitvies converted to strings for logging.
|
||||
|
||||
>>> params = {"float": 0.3,
|
||||
... "int": 1,
|
||||
... "string": "abc",
|
||||
... "bool": True,
|
||||
... "list": [1, 2, 3],
|
||||
... "namespace": Namespace(foo=3),
|
||||
... "layer": torch.nn.BatchNorm1d}
|
||||
>>> import pprint
|
||||
>>> pprint.pprint(_sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE
|
||||
{'bool': True,
|
||||
'float': 0.3,
|
||||
'int': 1,
|
||||
'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
|
||||
'list': '[1, 2, 3]',
|
||||
'namespace': 'Namespace(foo=3)',
|
||||
'string': 'abc'}
|
||||
"""
|
||||
for k in params.keys():
|
||||
# convert relevant np scalars to python types first (instead of str)
|
||||
if isinstance(params[k], (np.bool_, np.integer, np.floating)):
|
||||
params[k] = params[k].item()
|
||||
elif type(params[k]) not in [bool, int, float, str, torch.Tensor]:
|
||||
params[k] = str(params[k])
|
||||
return params
|
||||
|
||||
|
||||
def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
|
||||
"""Insert prefix before each key in a dict, separated by the separator.
|
||||
|
||||
Args:
|
||||
metrics: Dictionary with metric names as keys and measured quantities as values
|
||||
prefix: Prefix to insert before each key
|
||||
separator: Separates prefix and original key name
|
||||
|
||||
Returns:
|
||||
Dictionary with prefix and separator inserted before each key
|
||||
"""
|
||||
if prefix:
|
||||
metrics = {f"{prefix}{separator}{k}": v for k, v in metrics.items()}
|
||||
|
||||
return metrics
|
|
@ -26,6 +26,7 @@ from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, Ten
|
|||
from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.logger import _convert_params, _sanitize_params
|
||||
from tests.helpers.boring_model import BoringDataModule, BoringModel
|
||||
|
||||
|
||||
|
@ -290,8 +291,8 @@ def test_np_sanitization():
|
|||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params):
|
||||
params = self._convert_params(params)
|
||||
params = self._sanitize_params(params)
|
||||
params = _convert_params(params)
|
||||
params = _sanitize_params(params)
|
||||
self.logged_params = params
|
||||
|
||||
logger = CustomParamsLogger()
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import pickle
|
||||
from argparse import ArgumentParser
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
@ -279,36 +278,6 @@ def test_wandb_log_media(wandb, tmpdir):
|
|||
wandb.init().log.assert_called_once_with({"samples": wandb.Table(), "trainer/global_step": 5})
|
||||
|
||||
|
||||
def test_wandb_sanitize_callable_params(tmpdir):
|
||||
"""Callback function are not serializiable.
|
||||
|
||||
Therefore, we get them a chance to return something and if the returned type is not accepted, return None.
|
||||
"""
|
||||
opt = "--max_epochs 1".split(" ")
|
||||
parser = ArgumentParser()
|
||||
parser = Trainer.add_argparse_args(parent_parser=parser)
|
||||
params = parser.parse_args(opt)
|
||||
|
||||
def return_something():
|
||||
return "something"
|
||||
|
||||
params.something = return_something
|
||||
|
||||
def wrapper_something():
|
||||
return return_something
|
||||
|
||||
params.wrapper_something_wo_name = lambda: lambda: "1"
|
||||
params.wrapper_something = wrapper_something
|
||||
|
||||
params = WandbLogger._convert_params(params)
|
||||
params = WandbLogger._flatten_dict(params)
|
||||
params = WandbLogger._sanitize_callable_params(params)
|
||||
assert params["gpus"] == "None"
|
||||
assert params["something"] == "something"
|
||||
assert params["wrapper_something"] == "wrapper_something"
|
||||
assert params["wrapper_something_wo_name"] == "<lambda>"
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.loggers.wandb.wandb")
|
||||
def test_wandb_logger_offline_log_model(wandb, tmpdir):
|
||||
"""Test that log_model=True raises an error in offline mode."""
|
||||
|
|
|
@ -0,0 +1,174 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities.logger import (
|
||||
_add_prefix,
|
||||
_convert_params,
|
||||
_flatten_dict,
|
||||
_sanitize_callable_params,
|
||||
_sanitize_params,
|
||||
)
|
||||
|
||||
|
||||
def test_convert_params():
|
||||
"""Test conversion of params to a dict."""
|
||||
|
||||
# Test normal dict, make sure it is unchanged
|
||||
params = {"foo": "bar", 1: 23}
|
||||
assert type(params) == dict
|
||||
params = _convert_params(params)
|
||||
assert type(params) == dict
|
||||
assert params["foo"] == "bar"
|
||||
assert params[1] == 23
|
||||
|
||||
# Test None conversion
|
||||
params = None
|
||||
assert type(params) != dict
|
||||
params = _convert_params(params)
|
||||
assert type(params) == dict
|
||||
assert params == {}
|
||||
|
||||
# Test conversion of argparse Namespace
|
||||
opt = "--max_epochs 1".split(" ")
|
||||
parser = ArgumentParser()
|
||||
parser = Trainer.add_argparse_args(parent_parser=parser)
|
||||
params = parser.parse_args(opt)
|
||||
|
||||
assert type(params) == Namespace
|
||||
params = _convert_params(params)
|
||||
assert type(params) == dict
|
||||
assert params["gpus"] is None
|
||||
|
||||
|
||||
def test_flatten_dict():
|
||||
"""Validate flatten_dict can handle nested dictionaries and argparse Namespace."""
|
||||
|
||||
# Test basic dict flattening with custom delimiter
|
||||
params = {"a": {"b": "c"}}
|
||||
params = _flatten_dict(params, "--")
|
||||
|
||||
assert "a" not in params
|
||||
assert params["a--b"] == "c"
|
||||
|
||||
# Test complex nested dict flattening
|
||||
params = {"a": {5: {"foo": "bar"}}, "b": 6, "c": {7: [1, 2, 3, 4], 8: "foo", 9: {10: "bar"}}}
|
||||
params = _flatten_dict(params)
|
||||
|
||||
assert "a" not in params
|
||||
assert params["a/5/foo"] == "bar"
|
||||
assert params["b"] == 6
|
||||
assert params["c/7"] == [1, 2, 3, 4]
|
||||
assert params["c/8"] == "foo"
|
||||
assert params["c/9/10"] == "bar"
|
||||
|
||||
# Test flattening of argparse Namespace
|
||||
opt = "--max_epochs 1".split(" ")
|
||||
parser = ArgumentParser()
|
||||
parser = Trainer.add_argparse_args(parent_parser=parser)
|
||||
params = parser.parse_args(opt)
|
||||
wrapping_dict = {"params": params}
|
||||
params = _flatten_dict(wrapping_dict)
|
||||
|
||||
assert type(params) == dict
|
||||
assert params["params/logger"] is True
|
||||
assert params["params/gpus"] == "None"
|
||||
assert "logger" not in params
|
||||
assert "gpus" not in params
|
||||
|
||||
|
||||
def test_sanitize_callable_params():
|
||||
"""Callback function are not serializiable.
|
||||
|
||||
Therefore, we get them a chance to return something and if the returned type is not accepted, return None.
|
||||
"""
|
||||
opt = "--max_epochs 1".split(" ")
|
||||
parser = ArgumentParser()
|
||||
parser = Trainer.add_argparse_args(parent_parser=parser)
|
||||
params = parser.parse_args(opt)
|
||||
|
||||
def return_something():
|
||||
return "something"
|
||||
|
||||
params.something = return_something
|
||||
|
||||
def wrapper_something():
|
||||
return return_something
|
||||
|
||||
params.wrapper_something_wo_name = lambda: lambda: "1"
|
||||
params.wrapper_something = wrapper_something
|
||||
|
||||
params = _convert_params(params)
|
||||
params = _flatten_dict(params)
|
||||
params = _sanitize_callable_params(params)
|
||||
assert params["gpus"] == "None"
|
||||
assert params["something"] == "something"
|
||||
assert params["wrapper_something"] == "wrapper_something"
|
||||
assert params["wrapper_something_wo_name"] == "<lambda>"
|
||||
|
||||
|
||||
def test_sanitize_params():
|
||||
"""Verify sanitize params converts various types to loggable strings."""
|
||||
|
||||
params = {
|
||||
"float": 0.3,
|
||||
"int": 1,
|
||||
"string": "abc",
|
||||
"bool": True,
|
||||
"list": [1, 2, 3],
|
||||
"np_bool": np.bool_(False),
|
||||
"np_int": np.int_(5),
|
||||
"np_double": np.double(3.14159),
|
||||
"namespace": Namespace(foo=3),
|
||||
"layer": torch.nn.BatchNorm1d,
|
||||
"tensor": torch.ones(3),
|
||||
}
|
||||
params = _sanitize_params(params)
|
||||
|
||||
assert params["float"] == 0.3
|
||||
assert params["int"] == 1
|
||||
assert params["string"] == "abc"
|
||||
assert params["bool"] is True
|
||||
assert params["list"] == "[1, 2, 3]"
|
||||
assert params["np_bool"] is False
|
||||
assert params["np_int"] == 5
|
||||
assert params["np_double"] == 3.14159
|
||||
assert params["namespace"] == "Namespace(foo=3)"
|
||||
assert params["layer"] == "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>"
|
||||
assert torch.equal(params["tensor"], torch.ones(3))
|
||||
|
||||
|
||||
def test_add_prefix():
|
||||
"""Verify add_prefix modifies the dict keys correctly."""
|
||||
|
||||
metrics = {"metric1": 1, "metric2": 2}
|
||||
metrics = _add_prefix(metrics, "prefix", "-")
|
||||
|
||||
assert "prefix-metric1" in metrics
|
||||
assert "prefix-metric2" in metrics
|
||||
assert "metric1" not in metrics
|
||||
assert "metric2" not in metrics
|
||||
|
||||
metrics = _add_prefix(metrics, "prefix2", "_")
|
||||
|
||||
assert "prefix2_prefix-metric1" in metrics
|
||||
assert "prefix2_prefix-metric2" in metrics
|
||||
assert "prefix-metric1" not in metrics
|
||||
assert "prefix-metric2" not in metrics
|
||||
assert metrics["prefix2_prefix-metric1"] == 1
|
||||
assert metrics["prefix2_prefix-metric2"] == 2
|
Loading…
Reference in New Issue