feat(wandb): support media logging (#9545)

This commit is contained in:
Boris Dayma 2021-10-11 04:15:36 -05:00 committed by GitHub
parent ce8233e6f0
commit 2db9ea3500
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 315 additions and 21 deletions

View File

@ -146,6 +146,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344))
- Added `log_images`, `log_text` and `log_table` to `WandbLogger` ([#9545](https://github.com/PyTorchLightning/pytorch-lightning/pull/9545))
- Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))

View File

@ -251,7 +251,9 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except
self.log({"generated_images": [wandb.Image(some_img, caption="...")]})
.. seealso::
:class:`~pytorch_lightning.loggers.WandbLogger` docs.
- :class:`~pytorch_lightning.loggers.WandbLogger` docs.
- `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__
- `Demo in Google Colab <http://wandb.me/lightning>`__ with hyperparameter search and model logging
----------------

View File

@ -298,6 +298,20 @@ class LightningLoggerBase(ABC):
"""
pass
def log_text(self, *args, **kwargs) -> None:
"""Log text.
Arguments are directly passed to the logger.
"""
raise NotImplementedError
def log_image(self, *args, **kwargs) -> None:
"""Log image.
Arguments are directly passed to the logger.
"""
raise NotImplementedError
def save(self) -> None:
"""Save log data."""
self._finalize_agg_metrics()
@ -395,6 +409,14 @@ class LoggerCollection(LightningLoggerBase):
for logger in self._logger_iterable:
logger.log_graph(model, input_array)
def log_text(self, *args, **kwargs) -> None:
for logger in self._logger_iterable:
logger.log_text(*args, **kwargs)
def log_image(self, *args, **kwargs) -> None:
for logger in self._logger_iterable:
logger.log_image(*args, **kwargs)
def save(self) -> None:
for logger in self._logger_iterable:
logger.save()

View File

@ -19,7 +19,7 @@ import operator
import os
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
from weakref import ReferenceType
import torch.nn as nn
@ -46,12 +46,180 @@ class WandbLogger(LightningLoggerBase):
r"""
Log using `Weights and Biases <https://docs.wandb.ai/integrations/lightning>`_.
Install it with pip:
**Installation and set-up**
Install with pip:
.. code-block:: bash
pip install wandb
Create a `WandbLogger` instance:
.. code-block:: python
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(project="MNIST")
Pass the logger instance to the `Trainer`:
.. code-block:: python
trainer = Trainer(logger=wandb_logger)
A new W&B run will be created when training starts if you have not created one manually before with `wandb.init()`.
**Log metrics**
Log from :class:`~pytorch_lightning.core.lightning.LightningModule`:
.. code-block:: python
class LitModule(LightningModule):
def training_step(self, batch, batch_idx):
self.log("train/loss", loss)
Use directly wandb module:
.. code-block:: python
wandb.log({"train/loss": loss})
**Log hyper-parameters**
Save :class:`~pytorch_lightning.core.lightning.LightningModule` parameters:
.. code-block:: python
class LitModule(LightningModule):
def __init__(self, *args, **kwarg):
self.save_hyperparameters()
Add other config parameters:
.. code-block:: python
# add one parameter
wandb_logger.experiment.config["key"] = value
# add multiple parameters
wandb_logger.experiment.config.update({key1: val1, key2: val2})
# use directly wandb module
wandb.config["key"] = value
wandb.config.update()
**Log gradients, parameters and model topology**
Call the `watch` method for automatically tracking gradients:
.. code-block:: python
# log gradients and model topology
wandb_logger.watch(model)
# log gradients, parameter histogram and model topology
wandb_logger.watch(model, log="all")
# change log frequency of gradients and parameters (100 steps by default)
wandb_logger.watch(model, log_freq=500)
# do not log graph (in case of errors)
wandb_logger.watch(model, log_graph=False)
The `watch` method adds hooks to the model which can be removed at the end of training:
.. code-block:: python
wandb_logger.unwatch(model)
**Log model checkpoints**
Log model checkpoints at the end of training:
.. code-block:: python
wandb_logger = WandbLogger(log_model=True)
Log model checkpoints as they get created during training:
.. code-block:: python
wandb_logger = WandbLogger(log_model="all")
Custom checkpointing can be set up through :class:`~pytorch_lightning.callbacks.ModelCheckpoint`:
.. code-block:: python
# log model only if `val_accuracy` increases
wandb_logger = WandbLogger(log_model="all")
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
`latest` and `best` aliases are automatically set to easily retrieve a model checkpoint:
.. code-block:: python
# reference can be retrieved in artifacts panel
# "VERSION" can be a version (ex: "v2") or an alias ("latest or "best")
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"
# download checkpoint locally (if not already cached)
run = wandb.init(project="MNIST")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()
# load checkpoint
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
**Log media**
Log text with:
.. code-block:: python
# using columns and data
columns = ["input", "label", "prediction"]
data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]]
wandb_logger.log_text(key="samples", columns=columns, data=data)
# using a pandas DataFrame
wandb_logger.log_text(key="samples", dataframe=my_dataframe)
Log images with:
.. code-block:: python
# using tensors, numpy arrays or PIL images
wandb_logger.log_image(key="samples", images=[img1, img2])
# adding captions
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])
# using file path
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])
More arguments can be passed for logging segmentation masks and bounding boxes. Refer to
`Image Overlays documentation <https://docs.wandb.ai/guides/track/log/media#image-overlays>`_.
**Log Tables**
`W&B Tables <https://docs.wandb.ai/guides/data-vis>`_ can be used to log, query and analyze tabular data.
They support any type of media (text, image, video, audio, molecule, html, etc) and are great for storing,
understanding and sharing any form of data, from datasets to model predictions.
.. code-block:: python
columns = ["caption", "image", "sound"]
data = [["cheese", wandb.Image(img_1), wandb.Audio(snd_1)], ["wine", wandb.Image(img_2), wandb.Audio(snd_2)]]
wandb_logger.log_table(key="samples", columns=columns, data=data)
See Also:
- `Demo in Google Colab <http://wandb.me/lightning>`__ with hyperparameter search and model logging
- `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__
Args:
name: Display name for the run.
save_dir: Path where data is saved (wandb dir by default).
@ -61,7 +229,7 @@ class WandbLogger(LightningLoggerBase):
anonymous: Enables or explicitly disables anonymous logging.
project: The name of the project to which this run will belong.
log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
as W&B artifacts.
as W&B artifacts. `latest` and `best` aliases are automatically set.
* if ``log_model == 'all'``, checkpoints are logged during training.
* if ``log_model == True``, checkpoints are logged at the end of training, except when
@ -77,23 +245,7 @@ class WandbLogger(LightningLoggerBase):
ModuleNotFoundError:
If required WandB package is not installed on the device.
MisconfigurationException:
If both ``log_model`` and ``offline``is set to ``True``.
Example::
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
# instrument experiment with W&B
wandb_logger = WandbLogger(project='MNIST', log_model='all')
trainer = Trainer(logger=wandb_logger)
# log gradients and model topology
wandb_logger.watch(model)
See Also:
- `Demo in Google Colab <http://wandb.me/lightning>`__ with model logging
- `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__
If both ``log_model`` and ``offline`` is set to ``True``.
"""
@ -175,6 +327,8 @@ class WandbLogger(LightningLoggerBase):
Example::
.. code-block:: python
self.logger.experiment.some_wandb_function()
"""
@ -217,6 +371,56 @@ class WandbLogger(LightningLoggerBase):
else:
self.experiment.log(metrics)
@rank_zero_only
def log_table(
self,
key: str,
columns: List[str] = None,
data: List[List[Any]] = None,
dataframe: Any = None,
step: Optional[int] = None,
) -> None:
"""Log a Table containing any object type (text, image, audio, video, molecule, html, etc).
Can be defined either with `columns` and `data` or with `dataframe`.
"""
metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
self.log_metrics(metrics, step)
@rank_zero_only
def log_text(
self,
key: str,
columns: List[str] = None,
data: List[List[str]] = None,
dataframe: Any = None,
step: Optional[int] = None,
) -> None:
"""Log text as a Table.
Can be defined either with `columns` and `data` or with `dataframe`.
"""
self.log_table(key, columns, data, dataframe, step)
@rank_zero_only
def log_image(self, key: str, images: List[Any], **kwargs: str) -> None:
"""Log images (tensors, numpy arrays, PIL Images or file paths).
Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
"""
if not isinstance(images, list):
raise TypeError(f'Expected a list as "images", found {type(images)}')
n = len(images)
for k, v in kwargs.items():
if len(v) != n:
raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
step = kwargs.pop("step", None)
kwarg_list = [{k: kwargs[k][i] for k in kwargs.keys()} for i in range(n)]
metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]}
self.log_metrics(metrics, step)
@property
def save_dir(self) -> Optional[str]:
"""Gets the save directory.

View File

@ -216,6 +216,69 @@ def test_wandb_log_model(wandb, tmpdir):
)
@mock.patch("pytorch_lightning.loggers.wandb.wandb")
def test_wandb_log_media(wandb, tmpdir):
"""Test that the logger creates the folders and files in the right place."""
wandb.run = None
# test log_text with columns and data
columns = ["input", "label", "prediction"]
data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]]
logger = WandbLogger()
logger.log_text(key="samples", columns=columns, data=data)
wandb.Table.assert_called_once_with(
columns=["input", "label", "prediction"],
data=[["cheese", "english", "english"], ["fromage", "french", "spanish"]],
dataframe=None,
)
wandb.init().log.assert_called_once_with({"samples": wandb.Table()})
# test log_text with dataframe
wandb.Table.reset_mock()
wandb.init().log.reset_mock()
df = 'pandas.DataFrame({"col1": [1, 2], "col2": [3, 4]})' # TODO: incompatible numpy/pandas versions in test env
logger.log_text(key="samples", dataframe=df)
wandb.Table.assert_called_once_with(
columns=None,
data=None,
dataframe=df,
)
wandb.init().log.assert_called_once_with({"samples": wandb.Table()})
# test log_image
wandb.init().log.reset_mock()
logger.log_image(key="samples", images=["1.jpg", "2.jpg"])
wandb.Image.assert_called_with("2.jpg")
wandb.init().log.assert_called_once_with({"samples": [wandb.Image(), wandb.Image()]})
# test log_image with captions
wandb.init().log.reset_mock()
wandb.Image.reset_mock()
logger.log_image(key="samples", images=["1.jpg", "2.jpg"], caption=["caption 1", "caption 2"])
wandb.Image.assert_called_with("2.jpg", caption="caption 2")
wandb.init().log.assert_called_once_with({"samples": [wandb.Image(), wandb.Image()]})
# test log_image without a list
with pytest.raises(TypeError, match="""Expected a list as "images", found <class 'str'>"""):
logger.log_image(key="samples", images="1.jpg")
# test log_image with wrong number of captions
with pytest.raises(ValueError, match="Expected 2 items but only found 1 for caption"):
logger.log_image(key="samples", images=["1.jpg", "2.jpg"], caption=["caption 1"])
# test log_table
wandb.Table.reset_mock()
wandb.init().log.reset_mock()
logger.log_table(key="samples", columns=columns, data=data, dataframe=df, step=5)
wandb.Table.assert_called_once_with(
columns=columns,
data=data,
dataframe=df,
)
wandb.init().log.assert_called_once_with({"samples": wandb.Table(), "trainer/global_step": 5})
def test_wandb_sanitize_callable_params(tmpdir):
"""Callback function are not serializiable.