BUG - Wandb: Sanitize callable. (#4320)

* add _sanitize_callable_params

* add call on _val if callable

* clean code formatter

* resolve pep8

* default return function name

* resolve pep8

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update CHANGELOG.md

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
chaton 2020-10-26 11:57:03 +00:00 committed by GitHub
parent 376268f01e
commit f07ee33db6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 72 additions and 0 deletions

1
.gitignore vendored
View File

@ -138,3 +138,4 @@ mlruns/
*.ckpt
pytorch\ lightning
test-reports/
wandb

View File

@ -11,28 +11,44 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `dirpath` and `filename` parameter in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))
- Added plugins docs and DDPPlugin to customize ddp across all accelerators([#4258](https://github.com/PyTorchLightning/pytorch-lightning/pull/4285))
- Added `strict` option to the scheduler dictionary ([#3586](https://github.com/PyTorchLightning/pytorch-lightning/pull/3586))
- Added `fsspec` support for profilers ([#4162](https://github.com/PyTorchLightning/pytorch-lightning/pull/4162))
### Changed
- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
- Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130))
- Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273))
- Fixed santized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))
### Deprecated
- Deprecated `filepath` in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))
- Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))
### Removed
### Fixed
- Fixed setting device ids in DDP ([#4297](https://github.com/PyTorchLightning/pytorch-lightning/pull/4297))

View File

@ -168,6 +168,31 @@ class LightningLoggerBase(ABC):
return params
@staticmethod
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.
Args:
params: Dictionary containing the hyperparameters
Returns:
dictionary with all callables sanitized
"""
def _sanitize_callable(val):
# Give them one chance to return a value. Don't go rabbit hole of recursive call
if isinstance(val, Callable):
try:
_val = val()
if isinstance(_val, Callable):
return val.__name__
return _val
except Exception:
return val.__name__
return val
return {key: _sanitize_callable(val) for key, val in params.items()}
@staticmethod
def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
"""

View File

@ -135,6 +135,7 @@ class WandbLogger(LightningLoggerBase):
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
params = self._sanitize_callable_params(params)
self.experiment.config.update(params, allow_val_change=True)
@rank_zero_only

View File

@ -14,6 +14,8 @@
import os
import pickle
from unittest import mock
from argparse import ArgumentParser
import types
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
@ -109,3 +111,30 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
def test_wandb_sanitize_callable_params(tmpdir):
"""
Callback function are not serializiable. Therefore, we get them a chance to return
something and if the returned type is not accepted, return None.
"""
opt = "--max_epochs 1".split(" ")
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parent_parser=parser)
params = parser.parse_args(opt)
def return_something():
return "something"
params.something = return_something
def wrapper_something():
return return_something
params.wrapper_something = wrapper_something
assert isinstance(params.gpus, types.FunctionType)
params = WandbLogger._convert_params(params)
params = WandbLogger._flatten_dict(params)
params = WandbLogger._sanitize_callable_params(params)
assert params["gpus"] == '_gpus_arg_default'
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"