diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0250507d61..b9a063510c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -460,6 +460,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed optimizer `state` not moved to `GPU` ([#7277](https://github.com/PyTorchLightning/pytorch-lightning/pull/7277))
 
+- Fixed custom init args for `WandbLogger` ([#6989](https://github.com/PyTorchLightning/pytorch-lightning/pull/6989))
+
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed
diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index ad93d1de28..0f73153378 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -59,8 +59,7 @@ class WandbLogger(LightningLoggerBase):
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
         prefix: A string to put at the beginning of metric keys.
         experiment: WandB experiment object. Automatically set when creating a run.
-        \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
-            :func:`wandb.init` can be passed as keyword arguments in this logger.
+        \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
 
     Raises:
         ImportError:
@@ -93,7 +92,7 @@
         save_dir: Optional[str] = None,
         offline: Optional[bool] = False,
         id: Optional[str] = None,
-        anonymous: Optional[bool] = False,
+        anonymous: Optional[bool] = None,
         version: Optional[str] = None,
         project: Optional[str] = None,
         log_model: Optional[bool] = False,
@@ -122,16 +121,25 @@
             )
 
         super().__init__()
-        self._name = name
-        self._save_dir = save_dir
         self._offline = offline
-        self._id = version or id
-        self._anonymous = 'allow' if anonymous else None
-        self._project = project
         self._log_model = log_model
         self._prefix = prefix
         self._experiment = experiment
-        self._kwargs = kwargs
+        # set wandb init arguments
+        anonymous_lut = {True: 'allow', False: None}
+        self._wandb_init = dict(
+            name=name,
+            project=project,
+            id=version or id,
+            dir=save_dir,
+            resume='allow',
+            anonymous=anonymous_lut.get(anonymous, anonymous)
+        )
+        self._wandb_init.update(**kwargs)
+        # extract parameters
+        self._save_dir = self._wandb_init.get('dir')
+        self._name = self._wandb_init.get('name')
+        self._id = self._wandb_init.get('id')
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -158,15 +166,7 @@
         if self._experiment is None:
             if self._offline:
                 os.environ['WANDB_MODE'] = 'dryrun'
-            self._experiment = wandb.init(
-                name=self._name,
-                dir=self._save_dir,
-                project=self._project,
-                anonymous=self._anonymous,
-                id=self._id,
-                resume='allow',
-                **self._kwargs
-            ) if wandb.run is None else wandb.run
+            self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run
 
         # save checkpoints in wandb dir to upload on W&B servers
         if self._save_dir is None:
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 1bae877608..22be315eaa 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -37,9 +37,13 @@ def test_wandb_logger_init(wandb, recwarn):
 
     # test wandb.init called when there is no W&B run
     wandb.run = None
-    logger = WandbLogger()
+    logger = WandbLogger(
+        name='test_name', save_dir='test_save_dir', version='test_id', project='test_project', resume='never'
+    )
     logger.log_metrics({'acc': 1.0})
-    wandb.init.assert_called_once()
+    wandb.init.assert_called_once_with(
+        name='test_name', dir='test_save_dir', id='test_id', project='test_project', resume='never', anonymous=None
+    )
     wandb.init().log.assert_called_once_with({'acc': 1.0})
 
     # test wandb.init and setting logger experiment externally
@@ -55,6 +59,8 @@
     wandb.init.reset_mock()
     wandb.run = wandb.init()
    logger = WandbLogger()
+    # verify default resume value
+    assert logger._wandb_init['resume'] == 'allow'
     logger.log_metrics({'acc': 1.0}, step=3)
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})
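
For reference, a minimal standalone sketch of the kwargs-merge pattern this diff introduces. The helper name `build_wandb_init` is invented for illustration only; the `anonymous_lut` mapping, the `resume='allow'` default, and the `dict.update(**kwargs)` override step mirror the names in the diff above:

```python
# Hypothetical standalone sketch (not part of the PR) of the kwargs merge
# performed in WandbLogger.__init__ after this change.
anonymous_lut = {True: 'allow', False: None}


def build_wandb_init(name=None, save_dir=None, version=None, id=None,
                     project=None, anonymous=None, **kwargs):
    """Collect logger constructor arguments into one dict for wandb.init."""
    wandb_init = dict(
        name=name,
        project=project,
        id=version or id,  # `version` takes precedence over `id`
        dir=save_dir,
        resume='allow',  # default, overridable through **kwargs
        anonymous=anonymous_lut.get(anonymous, anonymous),
    )
    # User-supplied wandb.init kwargs (entity, group, tags, resume, ...)
    # override the defaults -- this is what makes custom init args work.
    wandb_init.update(**kwargs)
    return wandb_init


# Mirrors the updated test: resume='never' overrides the default 'allow'.
args = build_wandb_init(name='test_name', save_dir='test_save_dir',
                        version='test_id', project='test_project',
                        resume='never')
assert args == dict(name='test_name', dir='test_save_dir', id='test_id',
                    project='test_project', resume='never', anonymous=None)
```

The design point is that all `wandb.init` arguments now live in a single `_wandb_init` dict, built once and passed as `wandb.init(**self._wandb_init)`, so any keyword the user supplies flows through unchanged and can override the logger's defaults; `_save_dir`, `_name`, and `_id` are then read back from that dict to keep the logger's properties consistent with whatever was actually passed.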