From 9097347ea8e354c5bed1aec186be305c1216bc81 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 27 May 2021 13:15:02 -0500 Subject: [PATCH] feat(wandb): log models as artifacts (#6231) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 1 + docs/source/common/loggers.rst | 13 ++- .../callbacks/model_checkpoint.py | 5 ++ pytorch_lightning/loggers/base.py | 15 ++++ pytorch_lightning/loggers/wandb.py | 89 ++++++++++++++++--- tests/loggers/test_base.py | 5 ++ tests/loggers/test_wandb.py | 77 +++++++++++++--- 7 files changed, 176 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4259684748..22c7070cfb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -100,6 +100,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457)) +- Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231)) - MLFlowLogger now accepts `run_name` as an constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622)) diff --git a/docs/source/common/loggers.rst b/docs/source/common/loggers.rst index c6c5f0d865..5b1f13dbf4 100644 --- a/docs/source/common/loggers.rst +++ b/docs/source/common/loggers.rst @@ -202,7 +202,7 @@ The :class:`~pytorch_lightning.loggers.TestTubeLogger` is available anywhere exc Weights and Biases ================== -`Weights and Biases `_ is a third-party logger. +`Weights and Biases `_ is a third-party logger. To use :class:`~pytorch_lightning.loggers.WandbLogger` as your logger do the following. First, install the package: @@ -215,9 +215,14 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer. .. code-block:: python from pytorch_lightning.loggers import WandbLogger - wandb_logger = WandbLogger(offline=True) + + # instrument experiment with W&B + wandb_logger = WandbLogger(project='MNIST', log_model='all') trainer = Trainer(logger=wandb_logger) + # log gradients and model topology + wandb_logger.watch(model) + The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your :class:`~pytorch_lightning.core.lightning.LightningModule`. @@ -226,8 +231,8 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except class MyModule(LightningModule): def any_lightning_module_function_or_hook(self): some_img = fake_image() - self.logger.experiment.log({ - "generated_images": [wandb.Image(some_img, caption="...")] + self.log({ + "generated_images": [wandb.Image(some_img, caption="...")] }) .. seealso:: diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 7642ad95d0..067ebfdeaf 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -26,6 +26,7 @@ from copy import deepcopy from datetime import timedelta from pathlib import Path from typing import Any, Callable, Dict, Optional, Union +from weakref import proxy import numpy as np import torch @@ -330,6 +331,10 @@ class ModelCheckpoint(Callback): # Mode 3: save last checkpoints self._save_last_checkpoint(trainer, monitor_candidates) + # notify loggers + if trainer.is_global_zero and trainer.logger: + trainer.logger.after_save_checkpoint(proxy(self)) + def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool: from pytorch_lightning.trainer.states import TrainerFn return ( diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 035a42338f..7736ed24ba 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -20,10 +20,12 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union +from weakref import ReferenceType import numpy as np import torch +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only @@ -71,6 +73,15 @@ class LightningLoggerBase(ABC): self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func + def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: + """ + Called after model checkpoint callback saves a new checkpoint + + Args: + model_checkpoint: the model checkpoint callback instance + """ + pass + def update_agg_funcs( self, agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, @@ -357,6 +368,10 @@ class LoggerCollection(LightningLoggerBase): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] + def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: + for logger in self._logger_iterable: + logger.after_save_checkpoint(checkpoint_callback) + def update_agg_funcs( self, agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 0f73153378..c127fa037e 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -15,20 +15,26 @@ Weights and Biases Logger ------------------------- """ +import operator import os from argparse import Namespace +from pathlib import Path from typing import Any, Dict, Optional, Union +from weakref import ReferenceType import torch.nn as nn +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _compare_version from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() _WANDB_AVAILABLE = _module_available("wandb") +_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22") try: import wandb @@ -40,7 +46,7 @@ except ImportError: class WandbLogger(LightningLoggerBase): r""" - Log using `Weights and Biases `_. + Log using `Weights and Biases `_. Install it with pip: @@ -56,7 +62,15 @@ class WandbLogger(LightningLoggerBase): version: Same as id. anonymous: Enables or explicitly disables anonymous logging. project: The name of the project to which this run will belong. - log_model: Save checkpoints in wandb dir to upload on W&B servers. + log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` + as W&B artifacts. + + * if ``log_model == 'all'``, checkpoints are logged during training. + * if ``log_model == True``, checkpoints are logged at the end of training, except when + :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1`` + which also logs every checkpoint during training. + * if ``log_model == False`` (default), no checkpoint is logged. + prefix: A string to put at the beginning of metric keys. experiment: WandB experiment object. Automatically set when creating a run. \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc. @@ -71,15 +85,16 @@ class WandbLogger(LightningLoggerBase): from pytorch_lightning.loggers import WandbLogger from pytorch_lightning import Trainer - wandb_logger = WandbLogger() + + # instrument experiment with W&B + wandb_logger = WandbLogger(project='MNIST', log_model='all') trainer = Trainer(logger=wandb_logger) - Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`, - make sure to use `commit=False` so the logging step does not increase. + # log gradients and model topology + wandb_logger.watch(model) See Also: - - `Tutorial `__ - on how to use W&B with PyTorch Lightning + - `Demo in Google Colab `__ with model logging - `W&B Documentation `__ """ @@ -114,6 +129,13 @@ class WandbLogger(LightningLoggerBase): 'Hint: Set `offline=False` to log your model.' ) + if log_model and not _WANDB_GREATER_EQUAL_0_10_22: + warning_cache.warn( + f'Providing log_model={log_model} requires wandb version >= 0.10.22' + ' for logging associated model metadata.\n' + 'Hint: Upgrade with `pip install --ugrade wandb`.' + ) + if sync_step is not None: warning_cache.warn( "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5." @@ -125,6 +147,8 @@ class WandbLogger(LightningLoggerBase): self._log_model = log_model self._prefix = prefix self._experiment = experiment + self._logged_model_time = {} + self._checkpoint_callback = None # set wandb init arguments anonymous_lut = {True: 'allow', False: None} self._wandb_init = dict( @@ -168,10 +192,6 @@ class WandbLogger(LightningLoggerBase): os.environ['WANDB_MODE'] = 'dryrun' self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run - # save checkpoints in wandb dir to upload on W&B servers - if self._save_dir is None: - self._save_dir = self._experiment.dir - # define default x-axis (for latest wandb versions) if getattr(self._experiment, "define_metric", None): self._experiment.define_metric("trainer/global_step") @@ -213,8 +233,49 @@ class WandbLogger(LightningLoggerBase): # don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id + def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: + # log checkpoints as artifacts + if self._log_model == 'all' or self._log_model is True and checkpoint_callback.save_top_k == -1: + self._scan_and_log_checkpoints(checkpoint_callback) + elif self._log_model is True: + self._checkpoint_callback = checkpoint_callback + @rank_zero_only def finalize(self, status: str) -> None: - # upload all checkpoints from saving dir - if self._log_model: - wandb.save(os.path.join(self.save_dir, "*.ckpt")) + # log checkpoints as artifacts + if self._checkpoint_callback: + self._scan_and_log_checkpoints(self._checkpoint_callback) + + def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: + # get checkpoints to be saved with associated score + checkpoints = { + checkpoint_callback.last_model_path: checkpoint_callback.current_score, + checkpoint_callback.best_model_path: checkpoint_callback.best_model_score, + **checkpoint_callback.best_k_models + } + checkpoints = sorted([(Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file()]) + checkpoints = [ + c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0] + ] + + # log iteratively all new checkpoints + for t, p, s in checkpoints: + metadata = { + 'score': s, + 'original_filename': Path(p).name, + 'ModelCheckpoint': { + k: getattr(checkpoint_callback, k) + for k in [ + 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', '_every_n_train_steps', + '_every_n_val_epochs' + ] + # ensure it does not break if `ModelCheckpoint` args change + if hasattr(checkpoint_callback, k) + } + } if _WANDB_GREATER_EQUAL_0_10_22 else None + artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata) + artifact.add_file(p, name='model.ckpt') + aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] + self.experiment.log_artifact(artifact, aliases=aliases) + # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) + self._logged_model_time[p] = t diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index c20b609658..9209083148 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -59,6 +59,7 @@ class CustomLogger(LightningLoggerBase): self.hparams_logged = None self.metrics_logged = {} self.finalized = False + self.after_save_checkpoint_called = False @property def experiment(self): @@ -92,6 +93,9 @@ class CustomLogger(LightningLoggerBase): def version(self): return "1" + def after_save_checkpoint(self, checkpoint_callback): + self.after_save_checkpoint_called = True + def test_custom_logger(tmpdir): @@ -115,6 +119,7 @@ def test_custom_logger(tmpdir): assert trainer.state.finished, f"Training failed with {trainer.state}" assert logger.hparams_logged == model.hparams assert logger.metrics_logged != {} + assert logger.after_save_checkpoint_called assert logger.finalized_status == "success" diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 22be315eaa..27185b911b 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -24,14 +24,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel -def get_warnings(recwarn): - warnings_text = '\n'.join(str(w.message) for w in recwarn.list) - recwarn.clear() - return warnings_text - - @mock.patch('pytorch_lightning.loggers.wandb.wandb') -def test_wandb_logger_init(wandb, recwarn): +def test_wandb_logger_init(wandb): """Verify that basic functionality of wandb logger works. Wandb doesn't work well with pytest so we have to mock it out here.""" @@ -51,8 +45,6 @@ def test_wandb_logger_init(wandb, recwarn): run = wandb.init() logger = WandbLogger(experiment=run) assert logger.experiment - assert run.dir is not None - assert logger.save_dir == run.dir # test wandb.init not called if there is a W&B run wandb.init().log.reset_mock() @@ -140,10 +132,8 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): # mock return values of experiment wandb.run = None - wandb.init().step = 0 logger.experiment.id = '1' logger.experiment.project_name.return_value = 'project' - logger.experiment.step = 0 for _ in range(2): _ = logger.experiment @@ -164,6 +154,71 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): assert trainer.log_dir == logger.save_dir +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_wandb_log_model(wandb, tmpdir): + """ Test that the logger creates the folders and files in the right place. """ + + wandb.run = None + model = BoringModel() + + # test log_model=True + logger = WandbLogger(log_model=True) + logger.experiment.id = '1' + logger.experiment.project_name.return_value = 'project' + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + wandb.init().log_artifact.assert_called_once() + + # test log_model='all' + wandb.init().log_artifact.reset_mock() + wandb.init.reset_mock() + logger = WandbLogger(log_model='all') + logger.experiment.id = '1' + logger.experiment.project_name.return_value = 'project' + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + assert wandb.init().log_artifact.call_count == 2 + + # test log_model=False + wandb.init().log_artifact.reset_mock() + wandb.init.reset_mock() + logger = WandbLogger(log_model=False) + logger.experiment.id = '1' + logger.experiment.project_name.return_value = 'project' + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + assert not wandb.init().log_artifact.called + + # test correct metadata + import pytorch_lightning.loggers.wandb as pl_wandb + pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True + wandb.init().log_artifact.reset_mock() + wandb.init.reset_mock() + wandb.Artifact.reset_mock() + logger = pl_wandb.WandbLogger(log_model=True) + logger.experiment.id = '1' + logger.experiment.project_name.return_value = 'project' + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + wandb.Artifact.assert_called_once_with( + name='model-1', + type='model', + metadata={ + 'score': None, + 'original_filename': 'epoch=1-step=5-v3.ckpt', + 'ModelCheckpoint': { + 'monitor': None, + 'mode': 'min', + 'save_last': None, + 'save_top_k': None, + 'save_weights_only': False, + '_every_n_train_steps': 0, + '_every_n_val_epochs': 1 + } + } + ) + + def test_wandb_sanitize_callable_params(tmpdir): """ Callback function are not serializiable. Therefore, we get them a chance to return