feat(wandb): add sync_step (#5351)

* docs(wandb): add details to args
* feat(wandb): no sync between trainer and W&B steps
* style: pep8
* tests(wandb): test sync_step
* docs(wandb): add references
* docs(wandb): fix typo
* feat(wandb): more explicit warning
* feat(wandb): order of args
* style: Apply suggestions from code review
* style: long line

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
parent 0c9960bfbb
commit f0fafa2be0
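At a glance, the commit decouples the Trainer's global step from wandb's internal step counter. A minimal usage sketch, assuming only the public API added in this diff (the project name and metric values are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import WandbLogger

    # sync_step=False: metrics logged with an explicit step are recorded
    # under a 'trainer_step' key instead of driving wandb's step counter.
    wandb_logger = WandbLogger(project='my-project', sync_step=False)
    trainer = Trainer(logger=wandb_logger)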
@@ -49,19 +49,20 @@ class WandbLogger(LightningLoggerBase):
 
     Args:
         name: Display name for the run.
-        save_dir: Path where data is saved.
+        save_dir: Path where data is saved (wandb dir by default).
         offline: Run offline (data can be streamed later to wandb servers).
-        id: Sets the version, mainly used to resume a previous run.
-        version: Same as id.
-        anonymous: Enables or explicitly disables anonymous logging.
+        version: Sets the version, mainly used to resume a previous run.
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
-        experiment: WandB experiment object.
         prefix: A string to put at the beginning of metric keys.
+        sync_step: Sync Trainer step with wandb step.
+        experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.
 
-    Example::
+    Example:
+
+    .. code-block:: python
 
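As the updated docstring notes, `version` (or `id`) is mainly used to resume a previous run. A short hedged sketch of that use; the run id here is illustrative:

    # Passing the same version in a later training session reattaches to
    # the existing W&B run instead of creating a new one.
    wandb_logger = WandbLogger(project='my-project', version='1a2b3c4d')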
@@ -74,9 +75,9 @@ class WandbLogger(LightningLoggerBase):
     make sure to use `commit=False` so the logging step does not increase.
 
     See Also:
-        - `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
-          Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
-          on how to use W&B with Pytorch Lightning.
+        - `Tutorial <https://colab.research.google.com/drive/16d1uctGaw2y9KhGBlINNTsWpmlXdJwRW?usp=sharing>`__
+          on how to use W&B with PyTorch Lightning
+        - `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__
 
     """
 
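The `commit=False` note above concerns calling `wandb.log` directly alongside the logger. A hedged sketch of the intended pattern (the metric name is illustrative):

    import wandb

    # Log an extra value without advancing wandb's internal step counter,
    # so the logger's own step bookkeeping stays consistent.
    wandb.log({'gradient_norm': 0.5}, commit=False)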
@@ -86,14 +87,15 @@ class WandbLogger(LightningLoggerBase):
         self,
         name: Optional[str] = None,
         save_dir: Optional[str] = None,
-        offline: bool = False,
+        offline: Optional[bool] = False,
         id: Optional[str] = None,
-        anonymous: bool = False,
+        anonymous: Optional[bool] = False,
         version: Optional[str] = None,
         project: Optional[str] = None,
-        log_model: bool = False,
+        log_model: Optional[bool] = False,
         experiment=None,
-        prefix: str = '',
+        prefix: Optional[str] = '',
+        sync_step: Optional[bool] = True,
         **kwargs
     ):
         if wandb is None:
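With the widened signature, every flag can also be passed explicitly. A hedged construction example (all values illustrative):

    logger = WandbLogger(
        name='baseline',
        save_dir='./wandb_logs',
        offline=True,        # stream results to W&B servers later
        prefix='train',      # prepended to every metric key
        sync_step=True,      # default: wandb step follows the Trainer step
    )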
@@ -102,13 +104,14 @@ class WandbLogger(LightningLoggerBase):
         super().__init__()
         self._name = name
         self._save_dir = save_dir
-        self._anonymous = 'allow' if anonymous else None
-        self._id = version or id
-        self._project = project
-        self._experiment = experiment
         self._offline = offline
+        self._id = version or id
+        self._anonymous = 'allow' if anonymous else None
+        self._project = project
         self._log_model = log_model
         self._prefix = prefix
+        self._sync_step = sync_step
+        self._experiment = experiment
         self._kwargs = kwargs
         # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
         self._step_offset = 0
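The `_step_offset` kept here is what lets several sequential Trainer fits (k-fold, resuming) share one W&B run without the step ever going backwards. The arithmetic, illustrated standalone:

    # Suppose a first fit ended at global step 99 and the offset was bumped
    # to 100; a second fit restarting from local step 0 then logs at 100+.
    step_offset = 100
    local_step = 0
    effective_step = local_step + step_offset  # -> 100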
@@ -164,11 +167,16 @@ class WandbLogger(LightningLoggerBase):
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
-        if step is not None and step + self._step_offset < self.experiment.step:
+        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
             self.warning_cache.warn(
-                'Trying to log at a previous step. Use `commit=False` when logging metrics manually.'
-            )
-        self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
+                'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
+                ' or try logging with `commit=False` when calling manually `wandb.log`.')
+        if self._sync_step:
+            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
+        elif step is not None:
+            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        else:
+            self.experiment.log(metrics)
 
     @property
     def save_dir(self) -> Optional[str]:
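This branching is the core of the change. A standalone sketch of the same decision logic, with names mirroring the diff (the `experiment` object stands in for a wandb run):

    def log_metrics(experiment, metrics, step=None, sync_step=True, step_offset=0):
        if sync_step:
            # Drive wandb's step counter from the Trainer's global step.
            experiment.log(metrics, step=(step + step_offset) if step is not None else None)
        elif step is not None:
            # Leave wandb's counter alone; record the Trainer step as a metric.
            experiment.log({**metrics, 'trainer_step': step + step_offset})
        else:
            experiment.log(metrics)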
@@ -40,6 +40,18 @@ def test_wandb_logger_init(wandb, recwarn):
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
 
+    # test sync_step functionality
+    wandb.init().log.reset_mock()
+    wandb.init.reset_mock()
+    wandb.run = None
+    wandb.init().step = 0
+    logger = WandbLogger(sync_step=False)
+    logger.log_metrics({'acc': 1.0})
+    wandb.init().log.assert_called_once_with({'acc': 1.0})
+    wandb.init().log.reset_mock()
+    logger.log_metrics({'acc': 1.0}, step=3)
+    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})
+
     # mock wandb step
     wandb.init().step = 0
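A note on the mocking pattern used in the test, for anyone extending it: `unittest.mock.MagicMock` returns the same child mock on every call, so `wandb.init()` always yields the same fake run and `wandb.init().log` inspects that run's `log` method. A minimal standalone version of the idiom:

    from unittest import mock

    wandb = mock.MagicMock()
    run = wandb.init()      # the same mock object on every call
    run.log({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0})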