From ff41d80706c49e34f33215b1f1cf86619b19dfc9 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Wed, 28 Oct 2020 14:37:06 -0500 Subject: [PATCH] feat(wandb): log in sync with Trainer step (#4405) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(wandb): log in sync with Trainer step * docs: update CHANGELOG * style(test_wandb): fix formatting * parentheses Co-authored-by: Adrian Wälchli Co-authored-by: Rohit Gupta --- CHANGELOG.md | 3 +++ pytorch_lightning/loggers/wandb.py | 9 +++++++-- tests/loggers/test_wandb.py | 10 ++++++++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b20d4ae3ea..1de62b442f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed santized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320)) +- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405)) + + ### Deprecated diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index e6ce264d59..5786a52a8e 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -94,6 +94,8 @@ class WandbLogger(LightningLoggerBase): self._offline = offline self._log_model = log_model self._kwargs = kwargs + # logging multiple Trainer on a single W&B run (k-fold, etc) + self._step_offset = 0 def __getstate__(self): state = self.__dict__.copy() @@ -141,8 +143,7 @@ class WandbLogger(LightningLoggerBase): @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0' - - self.experiment.log({'global_step': step, **metrics} if step is not None else metrics) + self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None) @property def save_dir(self) -> Optional[str]: @@ -159,6 +160,10 @@ class WandbLogger(LightningLoggerBase): return self._experiment.id if self._experiment else self._id def finalize(self, status: str) -> None: + # offset future training logged on same W&B run + if self._experiment is not None: + self._step_offset = self._experiment.step + # upload all checkpoints from saving dir if self._log_model: wandb.save(os.path.join(self.save_dir, "*.ckpt")) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 6682cfdc88..cfb6533bd9 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -29,11 +29,17 @@ def test_wandb_logger(wandb): logger = WandbLogger(anonymous=True, offline=True) logger.log_metrics({'acc': 1.0}) - wandb.init().log.assert_called_once_with({'acc': 1.0}) + wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None) wandb.init().log.reset_mock() logger.log_metrics({'acc': 1.0}, step=3) - wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0}) + wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3) + + # continue training on same W&B run + wandb.init().step = 3 + logger.finalize('success') + logger.log_metrics({'acc': 1.0}, step=3) + wandb.init().log.assert_called_with({'acc': 1.0}, step=6) logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]}) wandb.init().config.update.assert_called_once_with(