From 2db9ea35006f5f540fd755ca9009eafc69b4447a Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Mon, 11 Oct 2021 04:15:36 -0500 Subject: [PATCH] feat(wandb): support media logging (#9545) --- CHANGELOG.md | 3 + docs/source/common/loggers.rst | 4 +- pytorch_lightning/loggers/base.py | 22 +++ pytorch_lightning/loggers/wandb.py | 244 ++++++++++++++++++++++++++--- tests/loggers/test_wandb.py | 63 ++++++++ 5 files changed, 315 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec0ab057e6..70044b8779 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -146,6 +146,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344)) +- Added `log_images`, `log_text` and `log_table` to `WandbLogger` ([#9545](https://github.com/PyTorchLightning/pytorch-lightning/pull/9545)) + + - Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389)) diff --git a/docs/source/common/loggers.rst b/docs/source/common/loggers.rst index fccb8a62b6..a49e9e3ec3 100644 --- a/docs/source/common/loggers.rst +++ b/docs/source/common/loggers.rst @@ -251,7 +251,9 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except self.log({"generated_images": [wandb.Image(some_img, caption="...")]}) .. seealso:: - :class:`~pytorch_lightning.loggers.WandbLogger` docs. + - :class:`~pytorch_lightning.loggers.WandbLogger` docs. + - `W&B Documentation `__ + - `Demo in Google Colab `__ with hyperparameter search and model logging ---------------- diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 71e7b9f902..5c64164541 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -298,6 +298,20 @@ class LightningLoggerBase(ABC): """ pass + def log_text(self, *args, **kwargs) -> None: + """Log text. + + Arguments are directly passed to the logger. + """ + raise NotImplementedError + + def log_image(self, *args, **kwargs) -> None: + """Log image. + + Arguments are directly passed to the logger. + """ + raise NotImplementedError + def save(self) -> None: """Save log data.""" self._finalize_agg_metrics() @@ -395,6 +409,14 @@ class LoggerCollection(LightningLoggerBase): for logger in self._logger_iterable: logger.log_graph(model, input_array) + def log_text(self, *args, **kwargs) -> None: + for logger in self._logger_iterable: + logger.log_text(*args, **kwargs) + + def log_image(self, *args, **kwargs) -> None: + for logger in self._logger_iterable: + logger.log_image(*args, **kwargs) + def save(self) -> None: for logger in self._logger_iterable: logger.save() diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 0e9d68b369..8d15e4fa6e 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -19,7 +19,7 @@ import operator import os from argparse import Namespace from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union from weakref import ReferenceType import torch.nn as nn @@ -46,12 +46,180 @@ class WandbLogger(LightningLoggerBase): r""" Log using `Weights and Biases `_. - Install it with pip: + **Installation and set-up** + + Install with pip: .. code-block:: bash pip install wandb + Create a `WandbLogger` instance: + + .. code-block:: python + + from pytorch_lightning.loggers import WandbLogger + + wandb_logger = WandbLogger(project="MNIST") + + Pass the logger instance to the `Trainer`: + + .. code-block:: python + + trainer = Trainer(logger=wandb_logger) + + A new W&B run will be created when training starts if you have not created one manually before with `wandb.init()`. + + **Log metrics** + + Log from :class:`~pytorch_lightning.core.lightning.LightningModule`: + + .. code-block:: python + + class LitModule(LightningModule): + def training_step(self, batch, batch_idx): + self.log("train/loss", loss) + + Use directly wandb module: + + .. code-block:: python + + wandb.log({"train/loss": loss}) + + **Log hyper-parameters** + + Save :class:`~pytorch_lightning.core.lightning.LightningModule` parameters: + + .. code-block:: python + + class LitModule(LightningModule): + def __init__(self, *args, **kwarg): + self.save_hyperparameters() + + Add other config parameters: + + .. code-block:: python + + # add one parameter + wandb_logger.experiment.config["key"] = value + + # add multiple parameters + wandb_logger.experiment.config.update({key1: val1, key2: val2}) + + # use directly wandb module + wandb.config["key"] = value + wandb.config.update() + + **Log gradients, parameters and model topology** + + Call the `watch` method for automatically tracking gradients: + + .. code-block:: python + + # log gradients and model topology + wandb_logger.watch(model) + + # log gradients, parameter histogram and model topology + wandb_logger.watch(model, log="all") + + # change log frequency of gradients and parameters (100 steps by default) + wandb_logger.watch(model, log_freq=500) + + # do not log graph (in case of errors) + wandb_logger.watch(model, log_graph=False) + + The `watch` method adds hooks to the model which can be removed at the end of training: + + .. code-block:: python + + wandb_logger.unwatch(model) + + **Log model checkpoints** + + Log model checkpoints at the end of training: + + .. code-block:: python + + wandb_logger = WandbLogger(log_model=True) + + Log model checkpoints as they get created during training: + + .. code-block:: python + + wandb_logger = WandbLogger(log_model="all") + + Custom checkpointing can be set up through :class:`~pytorch_lightning.callbacks.ModelCheckpoint`: + + .. code-block:: python + + # log model only if `val_accuracy` increases + wandb_logger = WandbLogger(log_model="all") + checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max") + trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback]) + + `latest` and `best` aliases are automatically set to easily retrieve a model checkpoint: + + .. code-block:: python + + # reference can be retrieved in artifacts panel + # "VERSION" can be a version (ex: "v2") or an alias ("latest or "best") + checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION" + + # download checkpoint locally (if not already cached) + run = wandb.init(project="MNIST") + artifact = run.use_artifact(checkpoint_reference, type="model") + artifact_dir = artifact.download() + + # load checkpoint + model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt") + + **Log media** + + Log text with: + + .. code-block:: python + + # using columns and data + columns = ["input", "label", "prediction"] + data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]] + wandb_logger.log_text(key="samples", columns=columns, data=data) + + # using a pandas DataFrame + wandb_logger.log_text(key="samples", dataframe=my_dataframe) + + Log images with: + + .. code-block:: python + + # using tensors, numpy arrays or PIL images + wandb_logger.log_image(key="samples", images=[img1, img2]) + + # adding captions + wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"]) + + # using file path + wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"]) + + More arguments can be passed for logging segmentation masks and bounding boxes. Refer to + `Image Overlays documentation `_. + + **Log Tables** + + `W&B Tables `_ can be used to log, query and analyze tabular data. + + They support any type of media (text, image, video, audio, molecule, html, etc) and are great for storing, + understanding and sharing any form of data, from datasets to model predictions. + + .. code-block:: python + + columns = ["caption", "image", "sound"] + data = [["cheese", wandb.Image(img_1), wandb.Audio(snd_1)], ["wine", wandb.Image(img_2), wandb.Audio(snd_2)]] + wandb_logger.log_table(key="samples", columns=columns, data=data) + + See Also: + - `Demo in Google Colab `__ with hyperparameter search and model logging + - `W&B Documentation `__ + Args: name: Display name for the run. save_dir: Path where data is saved (wandb dir by default). @@ -61,7 +229,7 @@ class WandbLogger(LightningLoggerBase): anonymous: Enables or explicitly disables anonymous logging. project: The name of the project to which this run will belong. log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` - as W&B artifacts. + as W&B artifacts. `latest` and `best` aliases are automatically set. * if ``log_model == 'all'``, checkpoints are logged during training. * if ``log_model == True``, checkpoints are logged at the end of training, except when @@ -77,23 +245,7 @@ class WandbLogger(LightningLoggerBase): ModuleNotFoundError: If required WandB package is not installed on the device. MisconfigurationException: - If both ``log_model`` and ``offline``is set to ``True``. - - Example:: - - from pytorch_lightning.loggers import WandbLogger - from pytorch_lightning import Trainer - - # instrument experiment with W&B - wandb_logger = WandbLogger(project='MNIST', log_model='all') - trainer = Trainer(logger=wandb_logger) - - # log gradients and model topology - wandb_logger.watch(model) - - See Also: - - `Demo in Google Colab `__ with model logging - - `W&B Documentation `__ + If both ``log_model`` and ``offline`` is set to ``True``. """ @@ -175,6 +327,8 @@ class WandbLogger(LightningLoggerBase): Example:: + .. code-block:: python + self.logger.experiment.some_wandb_function() """ @@ -217,6 +371,56 @@ class WandbLogger(LightningLoggerBase): else: self.experiment.log(metrics) + @rank_zero_only + def log_table( + self, + key: str, + columns: List[str] = None, + data: List[List[Any]] = None, + dataframe: Any = None, + step: Optional[int] = None, + ) -> None: + """Log a Table containing any object type (text, image, audio, video, molecule, html, etc). + + Can be defined either with `columns` and `data` or with `dataframe`. + """ + + metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)} + self.log_metrics(metrics, step) + + @rank_zero_only + def log_text( + self, + key: str, + columns: List[str] = None, + data: List[List[str]] = None, + dataframe: Any = None, + step: Optional[int] = None, + ) -> None: + """Log text as a Table. + + Can be defined either with `columns` and `data` or with `dataframe`. + """ + + self.log_table(key, columns, data, dataframe, step) + + @rank_zero_only + def log_image(self, key: str, images: List[Any], **kwargs: str) -> None: + """Log images (tensors, numpy arrays, PIL Images or file paths). + + Optional kwargs are lists passed to each image (ex: caption, masks, boxes). + """ + if not isinstance(images, list): + raise TypeError(f'Expected a list as "images", found {type(images)}') + n = len(images) + for k, v in kwargs.items(): + if len(v) != n: + raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") + step = kwargs.pop("step", None) + kwarg_list = [{k: kwargs[k][i] for k in kwargs.keys()} for i in range(n)] + metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]} + self.log_metrics(metrics, step) + @property def save_dir(self) -> Optional[str]: """Gets the save directory. diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 8388d7877a..85b20c5624 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -216,6 +216,69 @@ def test_wandb_log_model(wandb, tmpdir): ) +@mock.patch("pytorch_lightning.loggers.wandb.wandb") +def test_wandb_log_media(wandb, tmpdir): + """Test that the logger creates the folders and files in the right place.""" + + wandb.run = None + + # test log_text with columns and data + columns = ["input", "label", "prediction"] + data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]] + logger = WandbLogger() + logger.log_text(key="samples", columns=columns, data=data) + wandb.Table.assert_called_once_with( + columns=["input", "label", "prediction"], + data=[["cheese", "english", "english"], ["fromage", "french", "spanish"]], + dataframe=None, + ) + wandb.init().log.assert_called_once_with({"samples": wandb.Table()}) + + # test log_text with dataframe + wandb.Table.reset_mock() + wandb.init().log.reset_mock() + df = 'pandas.DataFrame({"col1": [1, 2], "col2": [3, 4]})' # TODO: incompatible numpy/pandas versions in test env + logger.log_text(key="samples", dataframe=df) + wandb.Table.assert_called_once_with( + columns=None, + data=None, + dataframe=df, + ) + wandb.init().log.assert_called_once_with({"samples": wandb.Table()}) + + # test log_image + wandb.init().log.reset_mock() + logger.log_image(key="samples", images=["1.jpg", "2.jpg"]) + wandb.Image.assert_called_with("2.jpg") + wandb.init().log.assert_called_once_with({"samples": [wandb.Image(), wandb.Image()]}) + + # test log_image with captions + wandb.init().log.reset_mock() + wandb.Image.reset_mock() + logger.log_image(key="samples", images=["1.jpg", "2.jpg"], caption=["caption 1", "caption 2"]) + wandb.Image.assert_called_with("2.jpg", caption="caption 2") + wandb.init().log.assert_called_once_with({"samples": [wandb.Image(), wandb.Image()]}) + + # test log_image without a list + with pytest.raises(TypeError, match="""Expected a list as "images", found """): + logger.log_image(key="samples", images="1.jpg") + + # test log_image with wrong number of captions + with pytest.raises(ValueError, match="Expected 2 items but only found 1 for caption"): + logger.log_image(key="samples", images=["1.jpg", "2.jpg"], caption=["caption 1"]) + + # test log_table + wandb.Table.reset_mock() + wandb.init().log.reset_mock() + logger.log_table(key="samples", columns=columns, data=data, dataframe=df, step=5) + wandb.Table.assert_called_once_with( + columns=columns, + data=data, + dataframe=df, + ) + wandb.init().log.assert_called_once_with({"samples": wandb.Table(), "trainer/global_step": 5}) + + def test_wandb_sanitize_callable_params(tmpdir): """Callback function are not serializiable.