feat(wandb): log in sync with Trainer step (#4405)
* feat(wandb): log in sync with Trainer step * docs: update CHANGELOG * style(test_wandb): fix formatting * parentheses Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
41de4538aa
commit
ff41d80706
|
@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Fixed santized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))
|
- Fixed santized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))
|
||||||
|
|
||||||
|
|
||||||
|
- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
|
||||||
|
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -94,6 +94,8 @@ class WandbLogger(LightningLoggerBase):
|
||||||
self._offline = offline
|
self._offline = offline
|
||||||
self._log_model = log_model
|
self._log_model = log_model
|
||||||
self._kwargs = kwargs
|
self._kwargs = kwargs
|
||||||
|
# logging multiple Trainer on a single W&B run (k-fold, etc)
|
||||||
|
self._step_offset = 0
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = self.__dict__.copy()
|
state = self.__dict__.copy()
|
||||||
|
@ -141,8 +143,7 @@ class WandbLogger(LightningLoggerBase):
|
||||||
@rank_zero_only
|
@rank_zero_only
|
||||||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
||||||
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
|
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
|
||||||
|
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
|
||||||
self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def save_dir(self) -> Optional[str]:
|
def save_dir(self) -> Optional[str]:
|
||||||
|
@ -159,6 +160,10 @@ class WandbLogger(LightningLoggerBase):
|
||||||
return self._experiment.id if self._experiment else self._id
|
return self._experiment.id if self._experiment else self._id
|
||||||
|
|
||||||
def finalize(self, status: str) -> None:
|
def finalize(self, status: str) -> None:
|
||||||
|
# offset future training logged on same W&B run
|
||||||
|
if self._experiment is not None:
|
||||||
|
self._step_offset = self._experiment.step
|
||||||
|
|
||||||
# upload all checkpoints from saving dir
|
# upload all checkpoints from saving dir
|
||||||
if self._log_model:
|
if self._log_model:
|
||||||
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
|
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
|
||||||
|
|
|
@ -29,11 +29,17 @@ def test_wandb_logger(wandb):
|
||||||
logger = WandbLogger(anonymous=True, offline=True)
|
logger = WandbLogger(anonymous=True, offline=True)
|
||||||
|
|
||||||
logger.log_metrics({'acc': 1.0})
|
logger.log_metrics({'acc': 1.0})
|
||||||
wandb.init().log.assert_called_once_with({'acc': 1.0})
|
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
|
||||||
|
|
||||||
wandb.init().log.reset_mock()
|
wandb.init().log.reset_mock()
|
||||||
logger.log_metrics({'acc': 1.0}, step=3)
|
logger.log_metrics({'acc': 1.0}, step=3)
|
||||||
wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})
|
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)
|
||||||
|
|
||||||
|
# continue training on same W&B run
|
||||||
|
wandb.init().step = 3
|
||||||
|
logger.finalize('success')
|
||||||
|
logger.log_metrics({'acc': 1.0}, step=3)
|
||||||
|
wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
|
||||||
|
|
||||||
logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
|
logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
|
||||||
wandb.init().config.update.assert_called_once_with(
|
wandb.init().config.update.assert_called_once_with(
|
||||||
|
|
Loading…
Reference in New Issue