fix(wandb): allow custom init args (#6989)

* feat(wandb): allow custom init args

* style: pep8

* fix: get dict args

* refactor: simplify init args

* test: test init args

* style: pep8

* docs: update CHANGELOG

* test: check default resume value

* fix: default value of anonymous

* fix: respect order of parameters

* feat: use look-up table for anonymous

* yapf formatting

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Boris Dayma 2021-05-04 04:45:36 -05:00 committed by GitHub
parent 82c19e1444
commit 2a20102321
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 20 deletions

View File

@ -460,6 +460,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed optimizer `state` not moved to `GPU` ([#7277](https://github.com/PyTorchLightning/pytorch-lightning/pull/7277))
- Fixed custom init args for `WandbLogger` ([#6989](https://github.com/PyTorchLightning/pytorch-lightning/pull/6989))
## [1.2.7] - 2021-04-06
### Fixed

View File

@ -59,8 +59,7 @@ class WandbLogger(LightningLoggerBase):
log_model: Save checkpoints in wandb dir to upload on W&B servers.
prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
:func:`wandb.init` can be passed as keyword arguments in this logger.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
Raises:
ImportError:
@ -93,7 +92,7 @@ class WandbLogger(LightningLoggerBase):
save_dir: Optional[str] = None,
offline: Optional[bool] = False,
id: Optional[str] = None,
anonymous: Optional[bool] = False,
anonymous: Optional[bool] = None,
version: Optional[str] = None,
project: Optional[str] = None,
log_model: Optional[bool] = False,
@ -122,16 +121,25 @@ class WandbLogger(LightningLoggerBase):
)
super().__init__()
self._name = name
self._save_dir = save_dir
self._offline = offline
self._id = version or id
self._anonymous = 'allow' if anonymous else None
self._project = project
self._log_model = log_model
self._prefix = prefix
self._experiment = experiment
self._kwargs = kwargs
# set wandb init arguments
anonymous_lut = {True: 'allow', False: None}
self._wandb_init = dict(
name=name,
project=project,
id=version or id,
dir=save_dir,
resume='allow',
anonymous=anonymous_lut.get(anonymous, anonymous)
)
self._wandb_init.update(**kwargs)
# extract parameters
self._save_dir = self._wandb_init.get('dir')
self._name = self._wandb_init.get('name')
self._id = self._wandb_init.get('id')
def __getstate__(self):
state = self.__dict__.copy()
@ -158,15 +166,7 @@ class WandbLogger(LightningLoggerBase):
if self._experiment is None:
if self._offline:
os.environ['WANDB_MODE'] = 'dryrun'
self._experiment = wandb.init(
name=self._name,
dir=self._save_dir,
project=self._project,
anonymous=self._anonymous,
id=self._id,
resume='allow',
**self._kwargs
) if wandb.run is None else wandb.run
self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run
# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:

View File

@ -37,9 +37,13 @@ def test_wandb_logger_init(wandb, recwarn):
# test wandb.init called when there is no W&B run
wandb.run = None
logger = WandbLogger()
logger = WandbLogger(
name='test_name', save_dir='test_save_dir', version='test_id', project='test_project', resume='never'
)
logger.log_metrics({'acc': 1.0})
wandb.init.assert_called_once()
wandb.init.assert_called_once_with(
name='test_name', dir='test_save_dir', id='test_id', project='test_project', resume='never', anonymous=None
)
wandb.init().log.assert_called_once_with({'acc': 1.0})
# test wandb.init and setting logger experiment externally
@ -55,6 +59,8 @@ def test_wandb_logger_init(wandb, recwarn):
wandb.init.reset_mock()
wandb.run = wandb.init()
logger = WandbLogger()
# verify default resume value
assert logger._wandb_init['resume'] == 'allow'
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})