feat(wandb): add sync_step (#5351)

* docs(wandb): add details to args

* feat(wandb): no sync between trainer and W&B steps

* style: pep8

* tests(wandb): test sync_step

* docs(wandb): add references

* docs(wandb): fix typo

* feat(wandb): more explicit warning

* feat(wandb): order of args

* style: Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* style: long line

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Boris Dayma, 2021-01-24 23:44:09 +01:00, committed by GitHub
commit f0fafa2be0 (parent 0c9960bfbb)
2 changed files with 39 additions and 19 deletions
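
A minimal usage sketch of the new flag, based only on the diff below (the project name is made up, and `offline=True` is added just so the snippet runs without W&B credentials):

from pytorch_lightning.loggers import WandbLogger

# Default behaviour (sync_step=True): the Trainer step is pinned to wandb's
# own step, so manual `wandb.log` calls can trigger the warning shown below.
logger = WandbLogger(project='my-project', offline=True)

# Opt out: wandb keeps its own step and the Trainer step is recorded as a
# regular metric under the key 'trainer_step'.
logger = WandbLogger(project='my-project', offline=True, sync_step=False)
logger.log_metrics({'acc': 1.0}, step=3)  # -> wandb.log({'acc': 1.0, 'trainer_step': 3})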


@@ -49,19 +49,20 @@ class WandbLogger(LightningLoggerBase):
     Args:
         name: Display name for the run.
-        save_dir: Path where data is saved.
+        save_dir: Path where data is saved (wandb dir by default).
         offline: Run offline (data can be streamed later to wandb servers).
         id: Sets the version, mainly used to resume a previous run.
+        version: Same as id.
         anonymous: Enables or explicitly disables anonymous logging.
-        version: Sets the version, mainly used to resume a previous run.
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
-        experiment: WandB experiment object.
         prefix: A string to put at the beginning of metric keys.
+        sync_step: Sync Trainer step with wandb step.
+        experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.
 
-    Example::
+    Example:
+
+    .. code-block:: python
@@ -74,9 +75,9 @@ class WandbLogger(LightningLoggerBase):
     make sure to use `commit=False` so the logging step does not increase.
 
     See Also:
-        - `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
-          Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
-          on how to use W&B with Pytorch Lightning.
+        - `Tutorial <https://colab.research.google.com/drive/16d1uctGaw2y9KhGBlINNTsWpmlXdJwRW?usp=sharing>`__
+          on how to use W&B with PyTorch Lightning
+        - `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__
     """
@@ -86,14 +87,15 @@ class WandbLogger(LightningLoggerBase):
         self,
         name: Optional[str] = None,
         save_dir: Optional[str] = None,
-        offline: bool = False,
+        offline: Optional[bool] = False,
         id: Optional[str] = None,
-        anonymous: bool = False,
+        anonymous: Optional[bool] = False,
         version: Optional[str] = None,
         project: Optional[str] = None,
-        log_model: bool = False,
+        log_model: Optional[bool] = False,
         experiment=None,
-        prefix: str = '',
+        prefix: Optional[str] = '',
+        sync_step: Optional[bool] = True,
         **kwargs
     ):
         if wandb is None:
@@ -102,13 +104,14 @@ class WandbLogger(LightningLoggerBase):
         super().__init__()
         self._name = name
         self._save_dir = save_dir
-        self._anonymous = 'allow' if anonymous else None
-        self._id = version or id
-        self._project = project
-        self._experiment = experiment
         self._offline = offline
+        self._id = version or id
+        self._anonymous = 'allow' if anonymous else None
+        self._project = project
         self._log_model = log_model
         self._prefix = prefix
+        self._sync_step = sync_step
+        self._experiment = experiment
         self._kwargs = kwargs
         # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
         self._step_offset = 0
@@ -164,11 +167,16 @@ class WandbLogger(LightningLoggerBase):
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
-        if step is not None and step + self._step_offset < self.experiment.step:
+        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
             self.warning_cache.warn(
-                'Trying to log at a previous step. Use `commit=False` when logging metrics manually.'
-            )
-        self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
+                'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
+                ' or try logging with `commit=False` when calling manually `wandb.log`.')
+        if self._sync_step:
+            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
+        elif step is not None:
+            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        else:
+            self.experiment.log(metrics)
 
     @property
     def save_dir(self) -> Optional[str]:
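
Condensed, the new branching in `log_metrics` behaves as in this standalone sketch (simplified from the hunk above; `experiment` stands in for the wandb run object):

def log_metrics(experiment, metrics, step=None, sync_step=True, step_offset=0):
    if sync_step:
        # Previous behaviour: pin wandb's step to the Trainer step.
        experiment.log(metrics, step=(step + step_offset) if step is not None else None)
    elif step is not None:
        # New behaviour: let wandb manage its own step and record the
        # Trainer step as an ordinary metric.
        experiment.log({**metrics, 'trainer_step': step + step_offset})
    else:
        experiment.log(metrics)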


@@ -40,6 +40,18 @@ def test_wandb_logger_init(wandb, recwarn):
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
 
+    # test sync_step functionality
+    wandb.init().log.reset_mock()
+    wandb.init.reset_mock()
+    wandb.run = None
+    wandb.init().step = 0
+    logger = WandbLogger(sync_step=False)
+    logger.log_metrics({'acc': 1.0})
+    wandb.init().log.assert_called_once_with({'acc': 1.0})
+    wandb.init().log.reset_mock()
+    logger.log_metrics({'acc': 1.0}, step=3)
+    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})
+
     # mock wandb step
     wandb.init().step = 0
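
The assertions work because the `wandb` argument is a mock patched in by the test setup, so `wandb.init()` returns the same child mock on every call and its `log` calls can be inspected afterwards. A self-contained sketch of that pattern (standard library only; not the repository's actual test harness):

from unittest.mock import MagicMock

wandb = MagicMock()
wandb.init().step = 0  # pretend the run has not advanced yet

# Every call made through the mock is recorded and can be asserted later.
wandb.init().log({'acc': 1.0, 'trainer_step': 3})
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})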