BUG - Wandb: Sanitize callable. (#4320)
* add _sanitize_callable_params * add call on _val if callable * clean code formatter * resolve pep8 * default return function name * resolve pep8 * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * Update CHANGELOG.md Co-authored-by: Sean Naren <sean.narenthiran@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
376268f01e
commit
f07ee33db6
|
@ -138,3 +138,4 @@ mlruns/
|
|||
*.ckpt
|
||||
pytorch\ lightning
|
||||
test-reports/
|
||||
wandb
|
||||
|
|
16
CHANGELOG.md
16
CHANGELOG.md
|
@ -11,28 +11,44 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added `dirpath` and `filename` parameter in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))
|
||||
|
||||
|
||||
- Added plugins docs and DDPPlugin to customize ddp across all accelerators ([#4285](https://github.com/PyTorchLightning/pytorch-lightning/pull/4285))
|
||||
|
||||
|
||||
- Added `strict` option to the scheduler dictionary ([#3586](https://github.com/PyTorchLightning/pytorch-lightning/pull/3586))
|
||||
|
||||
|
||||
- Added `fsspec` support for profilers ([#4162](https://github.com/PyTorchLightning/pytorch-lightning/pull/4162))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
|
||||
|
||||
|
||||
- Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130))
|
||||
|
||||
|
||||
- Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273))
|
||||
|
||||
|
||||
- Fixed sanitized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
||||
- Deprecated `filepath` in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))
|
||||
|
||||
|
||||
- Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed setting device ids in DDP ([#4297](https://github.com/PyTorchLightning/pytorch-lightning/pull/4297))
|
||||
|
|
|
@ -168,6 +168,31 @@ class LightningLoggerBase(ABC):
|
|||
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.
|
||||
|
||||
Args:
|
||||
params: Dictionary containing the hyperparameters
|
||||
|
||||
Returns:
|
||||
dictionary with all callables sanitized
|
||||
"""
|
||||
def _sanitize_callable(val):
|
||||
# Give them one chance to return a value. Don't go rabbit hole of recursive call
|
||||
if isinstance(val, Callable):
|
||||
try:
|
||||
_val = val()
|
||||
if isinstance(_val, Callable):
|
||||
return val.__name__
|
||||
return _val
|
||||
except Exception:
|
||||
return val.__name__
|
||||
return val
|
||||
|
||||
return {key: _sanitize_callable(val) for key, val in params.items()}
|
||||
|
||||
@staticmethod
|
||||
def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
|
@ -135,6 +135,7 @@ class WandbLogger(LightningLoggerBase):
|
|||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
    """Record hyperparameters in the wandb run config.

    The incoming params are converted to a plain dict, flattened, and
    stripped of callables before being handed to ``experiment.config``.
    """
    sanitized = self._sanitize_callable_params(
        self._flatten_dict(self._convert_params(params))
    )
    # allow_val_change lets wandb accept a key that was already logged.
    self.experiment.config.update(sanitized, allow_val_change=True)
|
||||
|
||||
@rank_zero_only
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
import os
|
||||
import pickle
|
||||
from unittest import mock
|
||||
from argparse import ArgumentParser
|
||||
import types
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
@ -109,3 +111,30 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
|
|||
|
||||
assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
|
||||
|
||||
|
||||
def test_wandb_sanitize_callable_params(tmpdir):
    """
    Callables are not serializable, so each one gets a single chance to
    return a value; if the returned type is not accepted either, the
    callable's name is used instead.
    """
    cli_args = "--max_epochs 1".split(" ")
    arg_parser = Trainer.add_argparse_args(parent_parser=ArgumentParser())
    hparams = arg_parser.parse_args(cli_args)

    def return_something():
        return "something"

    def wrapper_something():
        return return_something

    hparams.something = return_something
    hparams.wrapper_something = wrapper_something

    # gpus defaults to a callable default-factory on the parsed namespace.
    assert isinstance(hparams.gpus, types.FunctionType)
    sanitized = WandbLogger._sanitize_callable_params(
        WandbLogger._flatten_dict(WandbLogger._convert_params(hparams))
    )
    assert sanitized["gpus"] == '_gpus_arg_default'
    assert sanitized["something"] == "something"
    assert sanitized["wrapper_something"] == "wrapper_something"
|
||||
|
|
Loading…
Reference in New Issue