diff --git a/CHANGELOG.md b/CHANGELOG.md index b20d4ae3ea..1de62b442f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed santized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320)) +- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405)) + + ### Deprecated diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index e6ce264d59..5786a52a8e 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -94,6 +94,8 @@ class WandbLogger(LightningLoggerBase): self._offline = offline self._log_model = log_model self._kwargs = kwargs + # logging multiple Trainer on a single W&B run (k-fold, etc) + self._step_offset = 0 def __getstate__(self): state = self.__dict__.copy() @@ -141,8 +143,7 @@ class WandbLogger(LightningLoggerBase): @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0' - - self.experiment.log({'global_step': step, **metrics} if step is not None else metrics) + self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None) @property def save_dir(self) -> Optional[str]: @@ -159,6 +160,10 @@ class WandbLogger(LightningLoggerBase): return self._experiment.id if self._experiment else self._id def finalize(self, status: str) -> None: + # offset future training logged on same W&B run + if self._experiment is not None: + self._step_offset = self._experiment.step + # upload all checkpoints from saving dir if self._log_model: wandb.save(os.path.join(self.save_dir, "*.ckpt")) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 6682cfdc88..cfb6533bd9 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -29,11 +29,17 @@ def test_wandb_logger(wandb): logger = WandbLogger(anonymous=True, offline=True) logger.log_metrics({'acc': 1.0}) - wandb.init().log.assert_called_once_with({'acc': 1.0}) + wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None) wandb.init().log.reset_mock() logger.log_metrics({'acc': 1.0}, step=3) - wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0}) + wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3) + + # continue training on same W&B run + wandb.init().step = 3 + logger.finalize('success') + logger.log_metrics({'acc': 1.0}, step=3) + wandb.init().log.assert_called_with({'acc': 1.0}, step=6) logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]}) wandb.init().config.update.assert_called_once_with(