feat(wandb): support media logging (#9545)
@@ -146,6 +146,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344))
- Added `log_image`, `log_text` and `log_table` to `WandbLogger` ([#9545](https://github.com/PyTorchLightning/pytorch-lightning/pull/9545))
- Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
@@ -251,7 +251,9 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except

        self.log({"generated_images": [wandb.Image(some_img, caption="...")]})

.. seealso::
    :class:`~pytorch_lightning.loggers.WandbLogger` docs.
    - :class:`~pytorch_lightning.loggers.WandbLogger` docs.
    - `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__
    - `Demo in Google Colab <http://wandb.me/lightning>`__ with hyperparameter search and model logging
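With this commit, the snippet above no longer needs to construct `wandb.Image` by hand; a minimal sketch using the new `log_image` method (here `fake_image()` is a hypothetical stand-in for your own data):

.. code-block:: python

    class MyModule(LightningModule):
        def validation_epoch_end(self, outputs):
            some_img = fake_image()
            # `caption` is a per-image list, matching `images` in length
            self.logger.log_image(key="generated_images", images=[some_img], caption=["..."])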

----------------
@@ -298,6 +298,20 @@ class LightningLoggerBase(ABC):
        """
        pass

    def log_text(self, *args, **kwargs) -> None:
        """Log text.

        Arguments are directly passed to the logger.
        """
        raise NotImplementedError

    def log_image(self, *args, **kwargs) -> None:
        """Log image.

        Arguments are directly passed to the logger.
        """
        raise NotImplementedError

    def save(self) -> None:
        """Save log data."""
        self._finalize_agg_metrics()

@@ -395,6 +409,14 @@ class LoggerCollection(LightningLoggerBase):
        for logger in self._logger_iterable:
            logger.log_graph(model, input_array)

    def log_text(self, *args, **kwargs) -> None:
        for logger in self._logger_iterable:
            logger.log_text(*args, **kwargs)

    def log_image(self, *args, **kwargs) -> None:
        for logger in self._logger_iterable:
            logger.log_image(*args, **kwargs)

    def save(self) -> None:
        for logger in self._logger_iterable:
            logger.save()
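To illustrate the new `LoggerCollection` fan-out, a minimal sketch (not part of the diff; `TensorBoardLogger` is used here as an example of a logger that keeps the base-class stub):

.. code-block:: python

    from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger, WandbLogger

    loggers = LoggerCollection([WandbLogger(project="MNIST"), TensorBoardLogger("logs/")])

    try:
        # forwarded to each child in order: WandbLogger handles the call,
        # TensorBoardLogger inherits the stub and raises NotImplementedError
        loggers.log_image(key="samples", images=["img_1.jpg"])
    except NotImplementedError:
        pass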
@@ -19,7 +19,7 @@ import operator
import os
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
from weakref import ReferenceType

import torch.nn as nn
@@ -46,12 +46,180 @@ class WandbLogger(LightningLoggerBase):
r"""
Log using `Weights and Biases <https://docs.wandb.ai/integrations/lightning>`_.

Install it with pip:
**Installation and set-up**

Install with pip:

.. code-block:: bash

    pip install wandb

Create a `WandbLogger` instance:

.. code-block:: python

    from pytorch_lightning.loggers import WandbLogger

    wandb_logger = WandbLogger(project="MNIST")

Pass the logger instance to the `Trainer`:

.. code-block:: python

    trainer = Trainer(logger=wandb_logger)

A new W&B run will be created when training starts if you have not created one manually before with `wandb.init()`.
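Conversely, if a run was already created manually, the logger attaches to it; a minimal sketch of that pattern:

.. code-block:: python

    import wandb

    # create the run yourself, e.g. to set extra wandb.init() options;
    # WandbLogger reuses it instead of starting a new one when training begins
    wandb.init(project="MNIST")
    wandb_logger = WandbLogger(project="MNIST")
    trainer = Trainer(logger=wandb_logger)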

**Log metrics**

Log from :class:`~pytorch_lightning.core.lightning.LightningModule`:

.. code-block:: python

    class LitModule(LightningModule):
        def training_step(self, batch, batch_idx):
            self.log("train/loss", loss)

Or use the wandb module directly:

.. code-block:: python

    wandb.log({"train/loss": loss})

**Log hyper-parameters**

Save :class:`~pytorch_lightning.core.lightning.LightningModule` parameters:

.. code-block:: python

    class LitModule(LightningModule):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.save_hyperparameters()
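The saved values then appear in the run's config and are available under ``self.hparams``; for example (a sketch with a single hypothetical ``lr`` hyper-parameter):

.. code-block:: python

    import torch

    class LitModule(LightningModule):
        def __init__(self, lr: float = 1e-3):
            super().__init__()
            self.save_hyperparameters()  # records "lr" in self.hparams and the W&B config

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)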

Add other config parameters:

.. code-block:: python

    # add one parameter
    wandb_logger.experiment.config["key"] = value

    # add multiple parameters
    wandb_logger.experiment.config.update({key1: val1, key2: val2})

    # use the wandb module directly
    wandb.config["key"] = value
    wandb.config.update({key1: val1, key2: val2})

**Log gradients, parameters and model topology**

Call the `watch` method to track gradients automatically:

.. code-block:: python

    # log gradients and model topology
    wandb_logger.watch(model)

    # log gradients, parameter histogram and model topology
    wandb_logger.watch(model, log="all")

    # change log frequency of gradients and parameters (100 steps by default)
    wandb_logger.watch(model, log_freq=500)

    # do not log graph (in case of errors)
    wandb_logger.watch(model, log_graph=False)

The `watch` method adds hooks to the model which can be removed at the end of training:

.. code-block:: python

    wandb_logger.unwatch(model)

**Log model checkpoints**

Log model checkpoints at the end of training:

.. code-block:: python

    wandb_logger = WandbLogger(log_model=True)

Log model checkpoints as they get created during training:

.. code-block:: python

    wandb_logger = WandbLogger(log_model="all")

Custom checkpointing can be set up through :class:`~pytorch_lightning.callbacks.ModelCheckpoint`:

.. code-block:: python

    # log checkpoints only when `val_accuracy` increases
    wandb_logger = WandbLogger(log_model="all")
    checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
    trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])

`latest` and `best` aliases are automatically set to easily retrieve a model checkpoint:

.. code-block:: python

    # reference can be retrieved in artifacts panel
    # "VERSION" can be a version (ex: "v2") or an alias ("latest" or "best")
    checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"

    # download checkpoint locally (if not already cached)
    run = wandb.init(project="MNIST")
    artifact = run.use_artifact(checkpoint_reference, type="model")
    artifact_dir = artifact.download()

    # load checkpoint
    model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")

**Log media**

Log text with:

.. code-block:: python

    # using columns and data
    columns = ["input", "label", "prediction"]
    data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]]
    wandb_logger.log_text(key="samples", columns=columns, data=data)

    # using a pandas DataFrame
    wandb_logger.log_text(key="samples", dataframe=my_dataframe)

Log images with:

.. code-block:: python

    # using tensors, numpy arrays or PIL images
    wandb_logger.log_image(key="samples", images=[img1, img2])

    # adding captions
    wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

    # using file paths
    wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

More arguments can be passed for logging segmentation masks and bounding boxes. Refer to
`Image Overlays documentation <https://docs.wandb.ai/guides/track/log/media#image-overlays>`_.
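For instance, segmentation masks can be attached per image; a sketch assuming ``mask1``/``mask2`` are 2D numpy arrays of class ids, following the W&B image-overlay format:

.. code-block:: python

    # like `caption`, each extra kwarg is a list with one entry per image
    wandb_logger.log_image(
        key="samples",
        images=[img1, img2],
        masks=[
            {"ground_truth": {"mask_data": mask1, "class_labels": {0: "background", 1: "car"}}},
            {"ground_truth": {"mask_data": mask2, "class_labels": {0: "background", 1: "car"}}},
        ],
    )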

**Log Tables**

`W&B Tables <https://docs.wandb.ai/guides/data-vis>`_ can be used to log, query and analyze tabular data.

They support any type of media (text, image, video, audio, molecule, html, etc.) and are great for storing,
understanding and sharing any form of data, from datasets to model predictions.

.. code-block:: python

    columns = ["caption", "image", "sound"]
    data = [["cheese", wandb.Image(img_1), wandb.Audio(snd_1)], ["wine", wandb.Image(img_2), wandb.Audio(snd_2)]]
    wandb_logger.log_table(key="samples", columns=columns, data=data)

See Also:
    - `Demo in Google Colab <http://wandb.me/lightning>`__ with hyperparameter search and model logging
    - `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__

Args:
    name: Display name for the run.
    save_dir: Path where data is saved (wandb dir by default).

@@ -61,7 +229,7 @@ class WandbLogger(LightningLoggerBase):
    anonymous: Enables or explicitly disables anonymous logging.
    project: The name of the project to which this run will belong.
    log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
        as W&B artifacts.
        as W&B artifacts. `latest` and `best` aliases are automatically set.

        * if ``log_model == 'all'``, checkpoints are logged during training.
        * if ``log_model == True``, checkpoints are logged at the end of training, except when

@@ -77,23 +245,7 @@ class WandbLogger(LightningLoggerBase):
    ModuleNotFoundError:
        If the required WandB package is not installed on the device.
    MisconfigurationException:
        If both ``log_model`` and ``offline``is set to ``True``.

    Example::

        from pytorch_lightning.loggers import WandbLogger
        from pytorch_lightning import Trainer

        # instrument experiment with W&B
        wandb_logger = WandbLogger(project='MNIST', log_model='all')
        trainer = Trainer(logger=wandb_logger)

        # log gradients and model topology
        wandb_logger.watch(model)

    See Also:
        - `Demo in Google Colab <http://wandb.me/lightning>`__ with model logging
        - `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__

        If both ``log_model`` and ``offline`` are set to ``True``.
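A quick illustration of that misconfiguration (a sketch; checkpoints cannot be uploaded as artifacts while offline):

.. code-block:: python

    # raises MisconfigurationException
    wandb_logger = WandbLogger(offline=True, log_model=True)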
"""
@@ -175,6 +327,8 @@ class WandbLogger(LightningLoggerBase):

        Example::

            .. code-block:: python

                self.logger.experiment.some_wandb_function()

        """
@@ -217,6 +371,56 @@ class WandbLogger(LightningLoggerBase):
        else:
            self.experiment.log(metrics)

    @rank_zero_only
    def log_table(
        self,
        key: str,
        columns: Optional[List[str]] = None,
        data: Optional[List[List[Any]]] = None,
        dataframe: Any = None,
        step: Optional[int] = None,
    ) -> None:
        """Log a Table containing any object type (text, image, audio, video, molecule, html, etc).

        Can be defined either with `columns` and `data` or with `dataframe`.
        """
        metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
        self.log_metrics(metrics, step)

    @rank_zero_only
    def log_text(
        self,
        key: str,
        columns: Optional[List[str]] = None,
        data: Optional[List[List[str]]] = None,
        dataframe: Any = None,
        step: Optional[int] = None,
    ) -> None:
        """Log text as a Table.

        Can be defined either with `columns` and `data` or with `dataframe`.
        """
        self.log_table(key, columns, data, dataframe, step)

    @rank_zero_only
    def log_image(self, key: str, images: List[Any], **kwargs: Any) -> None:
        """Log images (tensors, numpy arrays, PIL Images or file paths).

        Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
        """
        if not isinstance(images, list):
            raise TypeError(f'Expected a list as "images", found {type(images)}')
        # pop `step` before validating kwargs: it is an int, not a per-image list
        step = kwargs.pop("step", None)
        n = len(images)
        for k, v in kwargs.items():
            if len(v) != n:
                raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
        # split the per-image lists into one kwargs dict per wandb.Image
        kwarg_list = [{k: kwargs[k][i] for k in kwargs.keys()} for i in range(n)]
        metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]}
        self.log_metrics(metrics, step)

    @property
    def save_dir(self) -> Optional[str]:
        """Gets the save directory.
@@ -216,6 +216,69 @@ def test_wandb_log_model(wandb, tmpdir):
    )


@mock.patch("pytorch_lightning.loggers.wandb.wandb")
def test_wandb_log_media(wandb, tmpdir):
    """Test that the logger forwards media (text, tables, images) to wandb correctly."""

    wandb.run = None

    # test log_text with columns and data
    columns = ["input", "label", "prediction"]
    data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]]
    logger = WandbLogger()
    logger.log_text(key="samples", columns=columns, data=data)
    wandb.Table.assert_called_once_with(
        columns=["input", "label", "prediction"],
        data=[["cheese", "english", "english"], ["fromage", "french", "spanish"]],
        dataframe=None,
    )
    wandb.init().log.assert_called_once_with({"samples": wandb.Table()})

    # test log_text with dataframe
    wandb.Table.reset_mock()
    wandb.init().log.reset_mock()
    df = 'pandas.DataFrame({"col1": [1, 2], "col2": [3, 4]})'  # TODO: incompatible numpy/pandas versions in test env
    logger.log_text(key="samples", dataframe=df)
    wandb.Table.assert_called_once_with(
        columns=None,
        data=None,
        dataframe=df,
    )
    wandb.init().log.assert_called_once_with({"samples": wandb.Table()})

    # test log_image
    wandb.init().log.reset_mock()
    logger.log_image(key="samples", images=["1.jpg", "2.jpg"])
    wandb.Image.assert_called_with("2.jpg")
    wandb.init().log.assert_called_once_with({"samples": [wandb.Image(), wandb.Image()]})

    # test log_image with captions
    wandb.init().log.reset_mock()
    wandb.Image.reset_mock()
    logger.log_image(key="samples", images=["1.jpg", "2.jpg"], caption=["caption 1", "caption 2"])
    wandb.Image.assert_called_with("2.jpg", caption="caption 2")
    wandb.init().log.assert_called_once_with({"samples": [wandb.Image(), wandb.Image()]})

    # test log_image without a list
    with pytest.raises(TypeError, match="""Expected a list as "images", found <class 'str'>"""):
        logger.log_image(key="samples", images="1.jpg")

    # test log_image with wrong number of captions
    with pytest.raises(ValueError, match="Expected 2 items but only found 1 for caption"):
        logger.log_image(key="samples", images=["1.jpg", "2.jpg"], caption=["caption 1"])

    # test log_table
    wandb.Table.reset_mock()
    wandb.init().log.reset_mock()
    logger.log_table(key="samples", columns=columns, data=data, dataframe=df, step=5)
    wandb.Table.assert_called_once_with(
        columns=columns,
        data=data,
        dataframe=df,
    )
    wandb.init().log.assert_called_once_with({"samples": wandb.Table(), "trainer/global_step": 5})


def test_wandb_sanitize_callable_params(tmpdir):
    """Callback functions are not serializable.