From f0fafa2be08fa46bb1b4bdddb2fc9add1adbb6f3 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 24 Jan 2021 23:44:09 +0100 Subject: [PATCH] feat(wandb): add sync_step (#5351) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * docs(wandb): add details to args * feat(wandb): no sync between trainer and W&B steps * style: pep8 * tests(wandb): test sync_step * docs(wandb): add references * docs(wandb): fix typo * feat(wandb): more explicit warning * feat(wandb): order of args * style: Apply suggestions from code review Co-authored-by: Adrian Wälchli * style: long line Co-authored-by: Adrian Wälchli Co-authored-by: Rohit Gupta Co-authored-by: Sean Naren --- pytorch_lightning/loggers/wandb.py | 46 ++++++++++++++++++------------ tests/loggers/test_wandb.py | 12 ++++++++ 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 443c50ac48..f588429aa8 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -49,19 +49,20 @@ class WandbLogger(LightningLoggerBase): Args: name: Display name for the run. - save_dir: Path where data is saved. + save_dir: Path where data is saved (wandb dir by default). offline: Run offline (data can be streamed later to wandb servers). id: Sets the version, mainly used to resume a previous run. + version: Same as id. anonymous: Enables or explicitly disables anonymous logging. - version: Sets the version, mainly used to resume a previous run. project: The name of the project to which this run will belong. log_model: Save checkpoints in wandb dir to upload on W&B servers. - experiment: WandB experiment object. prefix: A string to put at the beginning of metric keys. + sync_step: Sync Trainer step with wandb step. + experiment: WandB experiment object. Automatically set when creating a run. \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by :func:`wandb.init` can be passed as keyword arguments in this logger. - Example:: + Example: .. code-block:: python @@ -74,9 +75,9 @@ class WandbLogger(LightningLoggerBase): make sure to use `commit=False` so the logging step does not increase. See Also: - - `Tutorial `__ - on how to use W&B with Pytorch Lightning. + - `Tutorial `__ + on how to use W&B with PyTorch Lightning + - `W&B Documentation `__ """ @@ -86,14 +87,15 @@ class WandbLogger(LightningLoggerBase): self, name: Optional[str] = None, save_dir: Optional[str] = None, - offline: bool = False, + offline: Optional[bool] = False, id: Optional[str] = None, - anonymous: bool = False, + anonymous: Optional[bool] = False, version: Optional[str] = None, project: Optional[str] = None, - log_model: bool = False, + log_model: Optional[bool] = False, experiment=None, - prefix: str = '', + prefix: Optional[str] = '', + sync_step: Optional[bool] = True, **kwargs ): if wandb is None: @@ -102,13 +104,14 @@ class WandbLogger(LightningLoggerBase): super().__init__() self._name = name self._save_dir = save_dir - self._anonymous = 'allow' if anonymous else None - self._id = version or id - self._project = project - self._experiment = experiment self._offline = offline + self._id = version or id + self._anonymous = 'allow' if anonymous else None + self._project = project self._log_model = log_model self._prefix = prefix + self._sync_step = sync_step + self._experiment = experiment self._kwargs = kwargs # logging multiple Trainer on a single W&B run (k-fold, resuming, etc) self._step_offset = 0 @@ -164,11 +167,16 @@ class WandbLogger(LightningLoggerBase): assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0' metrics = self._add_prefix(metrics) - if step is not None and step + self._step_offset < self.experiment.step: + if self._sync_step and step is not None and step + self._step_offset < self.experiment.step: self.warning_cache.warn( - 'Trying to log at a previous step. Use `commit=False` when logging metrics manually.' - ) - self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None) + 'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`' + ' or try logging with `commit=False` when calling manually `wandb.log`.') + if self._sync_step: + self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None) + elif step is not None: + self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)}) + else: + self.experiment.log(metrics) @property def save_dir(self) -> Optional[str]: diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 94c947e8d6..fdbf8602d5 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -40,6 +40,18 @@ def test_wandb_logger_init(wandb, recwarn): wandb.init.assert_called_once() wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None) + # test sync_step functionality + wandb.init().log.reset_mock() + wandb.init.reset_mock() + wandb.run = None + wandb.init().step = 0 + logger = WandbLogger(sync_step=False) + logger.log_metrics({'acc': 1.0}) + wandb.init().log.assert_called_once_with({'acc': 1.0}) + wandb.init().log.reset_mock() + logger.log_metrics({'acc': 1.0}, step=3) + wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3}) + # mock wandb step wandb.init().step = 0