feat(wandb): log models as artifacts (#6231)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Boris Dayma 2021-05-27 13:15:02 -05:00 committed by GitHub
parent 9304c0df8f
commit 9097347ea8
7 changed files with 176 additions and 29 deletions

View File

@ -100,6 +100,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))
- Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231))
- MLFlowLogger now accepts `run_name` as a constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))

View File

@ -202,7 +202,7 @@ The :class:`~pytorch_lightning.loggers.TestTubeLogger` is available anywhere exc
Weights and Biases
==================
`Weights and Biases <https://www.wandb.com/>`_ is a third-party logger.
`Weights and Biases <https://docs.wandb.ai/integrations/lightning/>`_ is a third-party logger.
To use :class:`~pytorch_lightning.loggers.WandbLogger` as your logger do the following.
First, install the package:
@ -215,9 +215,14 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
.. code-block:: python
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(offline=True)
# instrument experiment with W&B
wandb_logger = WandbLogger(project='MNIST', log_model='all')
trainer = Trainer(logger=wandb_logger)
# log gradients and model topology
wandb_logger.watch(model)
The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your
:class:`~pytorch_lightning.core.lightning.LightningModule`.
@ -226,8 +231,8 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except
class MyModule(LightningModule):
def any_lightning_module_function_or_hook(self):
some_img = fake_image()
self.logger.experiment.log({
"generated_images": [wandb.Image(some_img, caption="...")]
self.log({
"generated_images": [wandb.Image(some_img, caption="...")]
})
.. seealso::

View File

@ -26,6 +26,7 @@ from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from weakref import proxy
import numpy as np
import torch
@ -330,6 +331,10 @@ class ModelCheckpoint(Callback):
# Mode 3: save last checkpoints
self._save_last_checkpoint(trainer, monitor_candidates)
# notify loggers
if trainer.is_global_zero and trainer.logger:
trainer.logger.after_save_checkpoint(proxy(self))
def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool:
from pytorch_lightning.trainer.states import TrainerFn
return (
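Since the callback is handed to the logger through `weakref.proxy`, a logger that keeps the reference around does not keep the `ModelCheckpoint` instance alive. A minimal standalone sketch of that behavior (the `DummyCheckpoint` class and its path are made up for illustration and are not part of this change):

.. code-block:: python

    import weakref

    class DummyCheckpoint:
        """Stand-in for ModelCheckpoint, only to illustrate the proxy."""
        best_model_path = "epoch=1-step=5.ckpt"

    cb = DummyCheckpoint()
    ref = weakref.proxy(cb)      # what a logger receives in after_save_checkpoint
    print(ref.best_model_path)   # attribute access is forwarded to the real object
    del cb                       # once the callback itself is gone...
    # ref.best_model_path        # ...the proxy raises ReferenceError instead of keeping it alive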

View File

@ -20,10 +20,12 @@ from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
from weakref import ReferenceType
import numpy as np
import torch
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only
@ -71,6 +73,15 @@ class LightningLoggerBase(ABC):
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
self._agg_default_func = agg_default_func
def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
"""
Called after model checkpoint callback saves a new checkpoint.
Args:
checkpoint_callback: the model checkpoint callback instance
"""
pass
def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
@ -357,6 +368,10 @@ class LoggerCollection(LightningLoggerBase):
def __getitem__(self, index: int) -> LightningLoggerBase:
return [logger for logger in self._logger_iterable][index]
def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
for logger in self._logger_iterable:
logger.after_save_checkpoint(checkpoint_callback)
def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
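For third-party loggers the new hook is opt-in: the base implementation is a no-op, and `LoggerCollection` simply forwards the call to every wrapped logger. Below is a minimal sketch of a custom logger that overrides the hook to remember the best checkpoint path; the `BestPathLogger` name and its attribute are invented for illustration, and the remaining methods only satisfy the abstract interface:

.. code-block:: python

    from pytorch_lightning.loggers import LightningLoggerBase


    class BestPathLogger(LightningLoggerBase):
        """Toy logger that only tracks the best checkpoint path it was told about."""

        def __init__(self):
            super().__init__()
            self.best_checkpoint_path = None

        def after_save_checkpoint(self, checkpoint_callback):
            # called by ModelCheckpoint on rank zero right after it saved a checkpoint
            self.best_checkpoint_path = checkpoint_callback.best_model_path

        # the methods below just satisfy LightningLoggerBase's abstract interface
        @property
        def experiment(self):
            return None

        def log_metrics(self, metrics, step=None):
            pass

        def log_hyperparams(self, params):
            pass

        @property
        def name(self):
            return "best-path"

        @property
        def version(self):
            return "0"

When several loggers are passed to the ``Trainer``, the resulting `LoggerCollection` forwards the hook to each of them, so such a logger can be combined with `WandbLogger` without extra wiring.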

View File

@ -15,20 +15,26 @@
Weights and Biases Logger
-------------------------
"""
import operator
import os
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, Union
from weakref import ReferenceType
import torch.nn as nn
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.warnings import WarningCache
warning_cache = WarningCache()
_WANDB_AVAILABLE = _module_available("wandb")
_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22")
try:
import wandb
@ -40,7 +46,7 @@ except ImportError:
class WandbLogger(LightningLoggerBase):
r"""
Log using `Weights and Biases <https://www.wandb.com/>`_.
Log using `Weights and Biases <https://docs.wandb.ai/integrations/lightning>`_.
Install it with pip:
@ -56,7 +62,15 @@ class WandbLogger(LightningLoggerBase):
version: Same as id.
anonymous: Enables or explicitly disables anonymous logging.
project: The name of the project to which this run will belong.
log_model: Save checkpoints in wandb dir to upload on W&B servers.
log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
as W&B artifacts.
* if ``log_model == 'all'``, checkpoints are logged during training.
* if ``log_model == True``, checkpoints are logged at the end of training, except when
:paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1``
which also logs every checkpoint during training.
* if ``log_model == False`` (default), no checkpoint is logged.
prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
@ -71,15 +85,16 @@ class WandbLogger(LightningLoggerBase):
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
wandb_logger = WandbLogger()
# instrument experiment with W&B
wandb_logger = WandbLogger(project='MNIST', log_model='all')
trainer = Trainer(logger=wandb_logger)
Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
make sure to use `commit=False` so the logging step does not increase.
# log gradients and model topology
wandb_logger.watch(model)
See Also:
- `Tutorial <https://colab.research.google.com/drive/16d1uctGaw2y9KhGBlINNTsWpmlXdJwRW?usp=sharing>`__
on how to use W&B with PyTorch Lightning
- `Demo in Google Colab <http://wandb.me/lightning>`__ with model logging
- `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__
"""
@ -114,6 +129,13 @@ class WandbLogger(LightningLoggerBase):
'Hint: Set `offline=False` to log your model.'
)
if log_model and not _WANDB_GREATER_EQUAL_0_10_22:
warning_cache.warn(
f'Providing log_model={log_model} requires wandb version >= 0.10.22'
' for logging associated model metadata.\n'
'Hint: Upgrade with `pip install --upgrade wandb`.'
)
if sync_step is not None:
warning_cache.warn(
"`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
@ -125,6 +147,8 @@ class WandbLogger(LightningLoggerBase):
self._log_model = log_model
self._prefix = prefix
self._experiment = experiment
self._logged_model_time = {}
self._checkpoint_callback = None
# set wandb init arguments
anonymous_lut = {True: 'allow', False: None}
self._wandb_init = dict(
@ -168,10 +192,6 @@ class WandbLogger(LightningLoggerBase):
os.environ['WANDB_MODE'] = 'dryrun'
self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run
# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
self._save_dir = self._experiment.dir
# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
@ -213,8 +233,49 @@ class WandbLogger(LightningLoggerBase):
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id
def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
# log checkpoints as artifacts
if self._log_model == 'all' or (self._log_model is True and checkpoint_callback.save_top_k == -1):
self._scan_and_log_checkpoints(checkpoint_callback)
elif self._log_model is True:
self._checkpoint_callback = checkpoint_callback
@rank_zero_only
def finalize(self, status: str) -> None:
# upload all checkpoints from saving dir
if self._log_model:
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
# log checkpoints as artifacts
if self._checkpoint_callback:
self._scan_and_log_checkpoints(self._checkpoint_callback)
def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
# get checkpoints to be saved with associated score
checkpoints = {
checkpoint_callback.last_model_path: checkpoint_callback.current_score,
checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
**checkpoint_callback.best_k_models
}
checkpoints = sorted([(Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file()])
checkpoints = [
c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]
]
# log iteratively all new checkpoints
for t, p, s in checkpoints:
metadata = {
'score': s,
'original_filename': Path(p).name,
'ModelCheckpoint': {
k: getattr(checkpoint_callback, k)
for k in [
'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', '_every_n_train_steps',
'_every_n_val_epochs'
]
# ensure it does not break if `ModelCheckpoint` args change
if hasattr(checkpoint_callback, k)
}
} if _WANDB_GREATER_EQUAL_0_10_22 else None
artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata)
artifact.add_file(p, name='model.ckpt')
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
self.experiment.log_artifact(artifact, aliases=aliases)
# remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
self._logged_model_time[p] = t
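Checkpoints logged this way can later be pulled back from W&B through the artifact API. A hedged sketch of what retrieval could look like (the project name, run id, and `LitModel` class are placeholders; the `model-<run id>` artifact name, the fixed `model.ckpt` file name, and the `best`/`latest` aliases follow from the code above):

.. code-block:: python

    import os
    import wandb

    # placeholders: use your own project and the id of the training run
    run = wandb.init(project="MNIST", job_type="evaluation")

    # checkpoints are logged as artifacts named "model-<run id>";
    # "best" points at best_model_path, "latest" at the most recently logged checkpoint
    artifact = run.use_artifact("model-1a2b3c4d:best", type="model")
    checkpoint_dir = artifact.download()

    # inside the artifact, the checkpoint file is always stored as "model.ckpt"
    ckpt_path = os.path.join(checkpoint_dir, "model.ckpt")
    # model = LitModel.load_from_checkpoint(ckpt_path)  # LitModel is a hypothetical LightningModule

Because the `best` alias always tracks `best_model_path`, downstream jobs do not need to know the exact checkpoint filename.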

View File

@ -59,6 +59,7 @@ class CustomLogger(LightningLoggerBase):
self.hparams_logged = None
self.metrics_logged = {}
self.finalized = False
self.after_save_checkpoint_called = False
@property
def experiment(self):
@ -92,6 +93,9 @@ class CustomLogger(LightningLoggerBase):
def version(self):
return "1"
def after_save_checkpoint(self, checkpoint_callback):
self.after_save_checkpoint_called = True
def test_custom_logger(tmpdir):
@ -115,6 +119,7 @@ def test_custom_logger(tmpdir):
assert trainer.state.finished, f"Training failed with {trainer.state}"
assert logger.hparams_logged == model.hparams
assert logger.metrics_logged != {}
assert logger.after_save_checkpoint_called
assert logger.finalized_status == "success"

View File

@ -24,14 +24,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
def get_warnings(recwarn):
warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
recwarn.clear()
return warnings_text
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger_init(wandb, recwarn):
def test_wandb_logger_init(wandb):
"""Verify that basic functionality of wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here."""
@ -51,8 +45,6 @@ def test_wandb_logger_init(wandb, recwarn):
run = wandb.init()
logger = WandbLogger(experiment=run)
assert logger.experiment
assert run.dir is not None
assert logger.save_dir == run.dir
# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
@ -140,10 +132,8 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
# mock return values of experiment
wandb.run = None
wandb.init().step = 0
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
logger.experiment.step = 0
for _ in range(2):
_ = logger.experiment
@ -164,6 +154,71 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
assert trainer.log_dir == logger.save_dir
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_log_model(wandb, tmpdir):
""" Test that the logger creates the folders and files in the right place. """
wandb.run = None
model = BoringModel()
# test log_model=True
logger = WandbLogger(log_model=True)
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(model)
wandb.init().log_artifact.assert_called_once()
# test log_model='all'
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
logger = WandbLogger(log_model='all')
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(model)
assert wandb.init().log_artifact.call_count == 2
# test log_model=False
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
logger = WandbLogger(log_model=False)
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(model)
assert not wandb.init().log_artifact.called
# test correct metadata
import pytorch_lightning.loggers.wandb as pl_wandb
pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
wandb.Artifact.reset_mock()
logger = pl_wandb.WandbLogger(log_model=True)
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(model)
wandb.Artifact.assert_called_once_with(
name='model-1',
type='model',
metadata={
'score': None,
'original_filename': 'epoch=1-step=5-v3.ckpt',
'ModelCheckpoint': {
'monitor': None,
'mode': 'min',
'save_last': None,
'save_top_k': None,
'save_weights_only': False,
'_every_n_train_steps': 0,
'_every_n_val_epochs': 1
}
}
)
def test_wandb_sanitize_callable_params(tmpdir):
"""
Callback functions are not serializable. Therefore, we give them a chance to return