fix(wandb): allow custom init args (#6989)
* feat(wandb): allow custom init args * style: pep8 * fix: get dict args * refactor: simplify init args * test: test init args * style: pep8 * docs: update CHANGELOG * test: check default resume value * fix: default value of anonymous * fix: respect order of parameters * feat: use look-up table for anonymous * yapf formatting Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
82c19e1444
commit
2a20102321
|
@ -460,6 +460,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed optimizer `state` not moved to `GPU` ([#7277](https://github.com/PyTorchLightning/pytorch-lightning/pull/7277))
|
||||
|
||||
|
||||
- Fixed custom init args for `WandbLogger` ([#6989](https://github.com/PyTorchLightning/pytorch-lightning/pull/6989))
|
||||
|
||||
|
||||
|
||||
## [1.2.7] - 2021-04-06
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -59,8 +59,7 @@ class WandbLogger(LightningLoggerBase):
|
|||
log_model: Save checkpoints in wandb dir to upload on W&B servers.
|
||||
prefix: A string to put at the beginning of metric keys.
|
||||
experiment: WandB experiment object. Automatically set when creating a run.
|
||||
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
|
||||
:func:`wandb.init` can be passed as keyword arguments in this logger.
|
||||
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
|
||||
|
||||
Raises:
|
||||
ImportError:
|
||||
|
@ -93,7 +92,7 @@ class WandbLogger(LightningLoggerBase):
|
|||
save_dir: Optional[str] = None,
|
||||
offline: Optional[bool] = False,
|
||||
id: Optional[str] = None,
|
||||
anonymous: Optional[bool] = False,
|
||||
anonymous: Optional[bool] = None,
|
||||
version: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
log_model: Optional[bool] = False,
|
||||
|
@ -122,16 +121,25 @@ class WandbLogger(LightningLoggerBase):
|
|||
)
|
||||
|
||||
super().__init__()
|
||||
self._name = name
|
||||
self._save_dir = save_dir
|
||||
self._offline = offline
|
||||
self._id = version or id
|
||||
self._anonymous = 'allow' if anonymous else None
|
||||
self._project = project
|
||||
self._log_model = log_model
|
||||
self._prefix = prefix
|
||||
self._experiment = experiment
|
||||
self._kwargs = kwargs
|
||||
# set wandb init arguments
|
||||
anonymous_lut = {True: 'allow', False: None}
|
||||
self._wandb_init = dict(
|
||||
name=name,
|
||||
project=project,
|
||||
id=version or id,
|
||||
dir=save_dir,
|
||||
resume='allow',
|
||||
anonymous=anonymous_lut.get(anonymous, anonymous)
|
||||
)
|
||||
self._wandb_init.update(**kwargs)
|
||||
# extract parameters
|
||||
self._save_dir = self._wandb_init.get('dir')
|
||||
self._name = self._wandb_init.get('name')
|
||||
self._id = self._wandb_init.get('id')
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
|
@ -158,15 +166,7 @@ class WandbLogger(LightningLoggerBase):
|
|||
if self._experiment is None:
|
||||
if self._offline:
|
||||
os.environ['WANDB_MODE'] = 'dryrun'
|
||||
self._experiment = wandb.init(
|
||||
name=self._name,
|
||||
dir=self._save_dir,
|
||||
project=self._project,
|
||||
anonymous=self._anonymous,
|
||||
id=self._id,
|
||||
resume='allow',
|
||||
**self._kwargs
|
||||
) if wandb.run is None else wandb.run
|
||||
self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run
|
||||
|
||||
# save checkpoints in wandb dir to upload on W&B servers
|
||||
if self._save_dir is None:
|
||||
|
|
|
@ -37,9 +37,13 @@ def test_wandb_logger_init(wandb, recwarn):
|
|||
|
||||
# test wandb.init called when there is no W&B run
|
||||
wandb.run = None
|
||||
logger = WandbLogger()
|
||||
logger = WandbLogger(
|
||||
name='test_name', save_dir='test_save_dir', version='test_id', project='test_project', resume='never'
|
||||
)
|
||||
logger.log_metrics({'acc': 1.0})
|
||||
wandb.init.assert_called_once()
|
||||
wandb.init.assert_called_once_with(
|
||||
name='test_name', dir='test_save_dir', id='test_id', project='test_project', resume='never', anonymous=None
|
||||
)
|
||||
wandb.init().log.assert_called_once_with({'acc': 1.0})
|
||||
|
||||
# test wandb.init and setting logger experiment externally
|
||||
|
@ -55,6 +59,8 @@ def test_wandb_logger_init(wandb, recwarn):
|
|||
wandb.init.reset_mock()
|
||||
wandb.run = wandb.init()
|
||||
logger = WandbLogger()
|
||||
# verify default resume value
|
||||
assert logger._wandb_init['resume'] == 'allow'
|
||||
logger.log_metrics({'acc': 1.0}, step=3)
|
||||
wandb.init.assert_called_once()
|
||||
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})
|
||||
|
|
Loading…
Reference in New Issue