fix(wandb): allow custom init args (#6989)
* feat(wandb): allow custom init args * style: pep8 * fix: get dict args * refactor: simplify init args * test: test init args * style: pep8 * docs: update CHANGELOG * test: check default resume value * fix: default value of anonymous * fix: respect order of parameters * feat: use look-up table for anonymous * yapf formatting Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
82c19e1444
commit
2a20102321
|
@ -460,6 +460,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Fixed optimizer `state` not moved to `GPU` ([#7277](https://github.com/PyTorchLightning/pytorch-lightning/pull/7277))
|
- Fixed optimizer `state` not moved to `GPU` ([#7277](https://github.com/PyTorchLightning/pytorch-lightning/pull/7277))
|
||||||
|
|
||||||
|
|
||||||
|
- Fixed custom init args for `WandbLogger` ([#6989](https://github.com/PyTorchLightning/pytorch-lightning/pull/6989))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## [1.2.7] - 2021-04-06
|
## [1.2.7] - 2021-04-06
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
|
@ -59,8 +59,7 @@ class WandbLogger(LightningLoggerBase):
|
||||||
log_model: Save checkpoints in wandb dir to upload on W&B servers.
|
log_model: Save checkpoints in wandb dir to upload on W&B servers.
|
||||||
prefix: A string to put at the beginning of metric keys.
|
prefix: A string to put at the beginning of metric keys.
|
||||||
experiment: WandB experiment object. Automatically set when creating a run.
|
experiment: WandB experiment object. Automatically set when creating a run.
|
||||||
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
|
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
|
||||||
:func:`wandb.init` can be passed as keyword arguments in this logger.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ImportError:
|
ImportError:
|
||||||
|
@ -93,7 +92,7 @@ class WandbLogger(LightningLoggerBase):
|
||||||
save_dir: Optional[str] = None,
|
save_dir: Optional[str] = None,
|
||||||
offline: Optional[bool] = False,
|
offline: Optional[bool] = False,
|
||||||
id: Optional[str] = None,
|
id: Optional[str] = None,
|
||||||
anonymous: Optional[bool] = False,
|
anonymous: Optional[bool] = None,
|
||||||
version: Optional[str] = None,
|
version: Optional[str] = None,
|
||||||
project: Optional[str] = None,
|
project: Optional[str] = None,
|
||||||
log_model: Optional[bool] = False,
|
log_model: Optional[bool] = False,
|
||||||
|
@ -122,16 +121,25 @@ class WandbLogger(LightningLoggerBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._name = name
|
|
||||||
self._save_dir = save_dir
|
|
||||||
self._offline = offline
|
self._offline = offline
|
||||||
self._id = version or id
|
|
||||||
self._anonymous = 'allow' if anonymous else None
|
|
||||||
self._project = project
|
|
||||||
self._log_model = log_model
|
self._log_model = log_model
|
||||||
self._prefix = prefix
|
self._prefix = prefix
|
||||||
self._experiment = experiment
|
self._experiment = experiment
|
||||||
self._kwargs = kwargs
|
# set wandb init arguments
|
||||||
|
anonymous_lut = {True: 'allow', False: None}
|
||||||
|
self._wandb_init = dict(
|
||||||
|
name=name,
|
||||||
|
project=project,
|
||||||
|
id=version or id,
|
||||||
|
dir=save_dir,
|
||||||
|
resume='allow',
|
||||||
|
anonymous=anonymous_lut.get(anonymous, anonymous)
|
||||||
|
)
|
||||||
|
self._wandb_init.update(**kwargs)
|
||||||
|
# extract parameters
|
||||||
|
self._save_dir = self._wandb_init.get('dir')
|
||||||
|
self._name = self._wandb_init.get('name')
|
||||||
|
self._id = self._wandb_init.get('id')
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = self.__dict__.copy()
|
state = self.__dict__.copy()
|
||||||
|
@ -158,15 +166,7 @@ class WandbLogger(LightningLoggerBase):
|
||||||
if self._experiment is None:
|
if self._experiment is None:
|
||||||
if self._offline:
|
if self._offline:
|
||||||
os.environ['WANDB_MODE'] = 'dryrun'
|
os.environ['WANDB_MODE'] = 'dryrun'
|
||||||
self._experiment = wandb.init(
|
self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run
|
||||||
name=self._name,
|
|
||||||
dir=self._save_dir,
|
|
||||||
project=self._project,
|
|
||||||
anonymous=self._anonymous,
|
|
||||||
id=self._id,
|
|
||||||
resume='allow',
|
|
||||||
**self._kwargs
|
|
||||||
) if wandb.run is None else wandb.run
|
|
||||||
|
|
||||||
# save checkpoints in wandb dir to upload on W&B servers
|
# save checkpoints in wandb dir to upload on W&B servers
|
||||||
if self._save_dir is None:
|
if self._save_dir is None:
|
||||||
|
|
|
@ -37,9 +37,13 @@ def test_wandb_logger_init(wandb, recwarn):
|
||||||
|
|
||||||
# test wandb.init called when there is no W&B run
|
# test wandb.init called when there is no W&B run
|
||||||
wandb.run = None
|
wandb.run = None
|
||||||
logger = WandbLogger()
|
logger = WandbLogger(
|
||||||
|
name='test_name', save_dir='test_save_dir', version='test_id', project='test_project', resume='never'
|
||||||
|
)
|
||||||
logger.log_metrics({'acc': 1.0})
|
logger.log_metrics({'acc': 1.0})
|
||||||
wandb.init.assert_called_once()
|
wandb.init.assert_called_once_with(
|
||||||
|
name='test_name', dir='test_save_dir', id='test_id', project='test_project', resume='never', anonymous=None
|
||||||
|
)
|
||||||
wandb.init().log.assert_called_once_with({'acc': 1.0})
|
wandb.init().log.assert_called_once_with({'acc': 1.0})
|
||||||
|
|
||||||
# test wandb.init and setting logger experiment externally
|
# test wandb.init and setting logger experiment externally
|
||||||
|
@ -55,6 +59,8 @@ def test_wandb_logger_init(wandb, recwarn):
|
||||||
wandb.init.reset_mock()
|
wandb.init.reset_mock()
|
||||||
wandb.run = wandb.init()
|
wandb.run = wandb.init()
|
||||||
logger = WandbLogger()
|
logger = WandbLogger()
|
||||||
|
# verify default resume value
|
||||||
|
assert logger._wandb_init['resume'] == 'allow'
|
||||||
logger.log_metrics({'acc': 1.0}, step=3)
|
logger.log_metrics({'acc': 1.0}, step=3)
|
||||||
wandb.init.assert_called_once()
|
wandb.init.assert_called_once()
|
||||||
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})
|
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})
|
||||||
|
|
Loading…
Reference in New Issue