feat(wandb): log models as artifacts (#6231)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent 9304c0df8f
commit 9097347ea8
@@ -100,6 +100,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))
- Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231))
- MLFlowLogger now accepts `run_name` as a constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))
@@ -202,7 +202,7 @@ The :class:`~pytorch_lightning.loggers.TestTubeLogger` is available anywhere exc
Weights and Biases
==================

`Weights and Biases <https://www.wandb.com/>`_ is a third-party logger.
`Weights and Biases <https://docs.wandb.ai/integrations/lightning/>`_ is a third-party logger.
To use :class:`~pytorch_lightning.loggers.WandbLogger` as your logger do the following.
First, install the package:
@@ -215,9 +215,14 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
.. code-block:: python

    from pytorch_lightning.loggers import WandbLogger

    wandb_logger = WandbLogger(offline=True)

    # instrument experiment with W&B
    wandb_logger = WandbLogger(project='MNIST', log_model='all')
    trainer = Trainer(logger=wandb_logger)

    # log gradients and model topology
    wandb_logger.watch(model)

The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your
:class:`~pytorch_lightning.core.lightning.LightningModule`.
@@ -226,8 +231,8 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except
    class MyModule(LightningModule):
        def any_lightning_module_function_or_hook(self):
            some_img = fake_image()
            self.logger.experiment.log({
                "generated_images": [wandb.Image(some_img, caption="...")]
            self.log({
                "generated_images": [wandb.Image(some_img, caption="...")]
            })
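Checkpoints uploaded this way can be pulled back into a later run. Below is a minimal sketch, assuming the ``model-<run id>`` artifact name and ``model.ckpt`` file name used by the logger, a placeholder run id, and reusing ``MyModule`` from the example above:

.. code-block:: python

    import wandb

    # run id of the training run that logged the checkpoints (placeholder value)
    run_id = "1abc234"

    run = wandb.init(project='MNIST')
    # WandbLogger attaches the "latest" and "best" aliases when logging artifacts
    artifact = run.use_artifact(f"model-{run_id}:best", type="model")
    artifact_dir = artifact.download()
    model = MyModule.load_from_checkpoint(f"{artifact_dir}/model.ckpt")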
.. seealso::
@@ -26,6 +26,7 @@ from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from weakref import proxy

import numpy as np
import torch
@@ -330,6 +331,10 @@ class ModelCheckpoint(Callback):
        # Mode 3: save last checkpoints
        self._save_last_checkpoint(trainer, monitor_candidates)

        # notify loggers
        if trainer.is_global_zero and trainer.logger:
            trainer.logger.after_save_checkpoint(proxy(self))
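The callback hands the logger a `weakref.proxy` of itself rather than a direct reference, presumably so a logger that stores the callback (as `WandbLogger` does with `self._checkpoint_callback`) does not keep it alive or create a reference cycle. A standalone sketch of the standard-library behaviour being relied on, with a hypothetical `Callback` stand-in:

    import weakref


    class Callback:
        best_model_path = "epoch=0.ckpt"


    cb = Callback()
    ref = weakref.proxy(cb)         # behaves like cb, but does not own it
    print(ref.best_model_path)      # -> "epoch=0.ckpt"

    del cb                          # once the real object is gone...
    try:
        ref.best_model_path         # ...the proxy raises instead of keeping it alive
    except ReferenceError:
        print("callback has been garbage collected")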

    def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool:
        from pytorch_lightning.trainer.states import TrainerFn
        return (
@@ -20,10 +20,12 @@ from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
from weakref import ReferenceType

import numpy as np
import torch

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only
@@ -71,6 +73,15 @@ class LightningLoggerBase(ABC):
        self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
        self._agg_default_func = agg_default_func

    def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
        """
        Called after model checkpoint callback saves a new checkpoint.

        Args:
            checkpoint_callback: the model checkpoint callback instance
        """
        pass
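Any logger can opt into this hook by overriding it. A minimal sketch of a hypothetical custom logger (the class name and behaviour are illustrative) that records the best checkpoint path each time `ModelCheckpoint` reports a save:

    from pytorch_lightning.loggers.base import LightningLoggerBase
    from pytorch_lightning.utilities import rank_zero_only


    class CheckpointTrackingLogger(LightningLoggerBase):
        """Illustrative logger that only tracks checkpoint paths reported via the hook."""

        def __init__(self):
            super().__init__()
            self.saved_checkpoints = []

        def after_save_checkpoint(self, checkpoint_callback):
            # called by ModelCheckpoint (on global rank zero) right after it wrote checkpoints
            self.saved_checkpoints.append(checkpoint_callback.best_model_path)

        @property
        def experiment(self):
            return None

        @rank_zero_only
        def log_hyperparams(self, params):
            pass

        @rank_zero_only
        def log_metrics(self, metrics, step=None):
            pass

        @property
        def name(self):
            return "checkpoint-tracking"

        @property
        def version(self):
            return "0"

Passing such a logger to `Trainer(logger=CheckpointTrackingLogger())` is enough; no changes to `ModelCheckpoint` are needed.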

    def update_agg_funcs(
        self,
        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
@@ -357,6 +368,10 @@ class LoggerCollection(LightningLoggerBase):
    def __getitem__(self, index: int) -> LightningLoggerBase:
        return [logger for logger in self._logger_iterable][index]

    def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
        for logger in self._logger_iterable:
            logger.after_save_checkpoint(checkpoint_callback)

    def update_agg_funcs(
        self,
        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
@@ -15,20 +15,26 @@
Weights and Biases Logger
-------------------------
"""
import operator
import os
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, Union
from weakref import ReferenceType

import torch.nn as nn

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()

_WANDB_AVAILABLE = _module_available("wandb")
_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22")

try:
    import wandb
@@ -40,7 +46,7 @@ except ImportError:

class WandbLogger(LightningLoggerBase):
    r"""
    Log using `Weights and Biases <https://www.wandb.com/>`_.
    Log using `Weights and Biases <https://docs.wandb.ai/integrations/lightning>`_.

    Install it with pip:
@@ -56,7 +62,15 @@ class WandbLogger(LightningLoggerBase):
        version: Same as id.
        anonymous: Enables or explicitly disables anonymous logging.
        project: The name of the project to which this run will belong.
        log_model: Save checkpoints in wandb dir to upload on W&B servers.
        log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
            as W&B artifacts.

            * if ``log_model == 'all'``, checkpoints are logged during training.
            * if ``log_model == True``, checkpoints are logged at the end of training, except when
              :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1``
              which also logs every checkpoint during training.
            * if ``log_model == False`` (default), no checkpoint is logged.

        prefix: A string to put at the beginning of metric keys.
        experiment: WandB experiment object. Automatically set when creating a run.
        \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
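How many checkpoints end up as artifacts is governed jointly by ``log_model`` and the ``ModelCheckpoint`` configuration. A minimal sketch of the two common combinations (the ``monitor`` and ``save_top_k`` values are illustrative only):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.loggers import WandbLogger

    # keep the two best checkpoints by validation loss (assumes the model logs "val_loss")
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=2)

    # log_model=True: the retained checkpoints are uploaded once, at the end of training
    logger = WandbLogger(project="MNIST", log_model=True)

    # log_model='all': every checkpoint is uploaded as soon as it is written
    # logger = WandbLogger(project="MNIST", log_model="all")

    trainer = Trainer(logger=logger, callbacks=[checkpoint_callback], max_epochs=5)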
@@ -71,15 +85,16 @@ class WandbLogger(LightningLoggerBase):

        from pytorch_lightning.loggers import WandbLogger
        from pytorch_lightning import Trainer

        wandb_logger = WandbLogger()

        # instrument experiment with W&B
        wandb_logger = WandbLogger(project='MNIST', log_model='all')
        trainer = Trainer(logger=wandb_logger)

    Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
    make sure to use `commit=False` so the logging step does not increase.

        # log gradients and model topology
        wandb_logger.watch(model)

    See Also:
        - `Tutorial <https://colab.research.google.com/drive/16d1uctGaw2y9KhGBlINNTsWpmlXdJwRW?usp=sharing>`__
          on how to use W&B with PyTorch Lightning
        - `Demo in Google Colab <http://wandb.me/lightning>`__ with model logging
        - `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__

    """
@@ -114,6 +129,13 @@ class WandbLogger(LightningLoggerBase):
                'Hint: Set `offline=False` to log your model.'
            )

        if log_model and not _WANDB_GREATER_EQUAL_0_10_22:
            warning_cache.warn(
                f'Providing log_model={log_model} requires wandb version >= 0.10.22'
                ' for logging associated model metadata.\n'
                'Hint: Upgrade with `pip install --upgrade wandb`.'
            )

        if sync_step is not None:
            warning_cache.warn(
                "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
@@ -125,6 +147,8 @@ class WandbLogger(LightningLoggerBase):
        self._log_model = log_model
        self._prefix = prefix
        self._experiment = experiment
        self._logged_model_time = {}
        self._checkpoint_callback = None
        # set wandb init arguments
        anonymous_lut = {True: 'allow', False: None}
        self._wandb_init = dict(
@@ -168,10 +192,6 @@ class WandbLogger(LightningLoggerBase):
                os.environ['WANDB_MODE'] = 'dryrun'
            self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run

            # save checkpoints in wandb dir to upload on W&B servers
            if self._save_dir is None:
                self._save_dir = self._experiment.dir

            # define default x-axis (for latest wandb versions)
            if getattr(self._experiment, "define_metric", None):
                self._experiment.define_metric("trainer/global_step")
@@ -213,8 +233,49 @@ class WandbLogger(LightningLoggerBase):
        # don't create an experiment if we don't have one
        return self._experiment.id if self._experiment else self._id

    def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
        # log checkpoints as artifacts
        if self._log_model == 'all' or self._log_model is True and checkpoint_callback.save_top_k == -1:
            self._scan_and_log_checkpoints(checkpoint_callback)
        elif self._log_model is True:
            self._checkpoint_callback = checkpoint_callback

    @rank_zero_only
    def finalize(self, status: str) -> None:
        # upload all checkpoints from saving dir
        if self._log_model:
            wandb.save(os.path.join(self.save_dir, "*.ckpt"))
        # log checkpoints as artifacts
        if self._checkpoint_callback:
            self._scan_and_log_checkpoints(self._checkpoint_callback)

    def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
        # get checkpoints to be saved with associated score
        checkpoints = {
            checkpoint_callback.last_model_path: checkpoint_callback.current_score,
            checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
            **checkpoint_callback.best_k_models
        }
        checkpoints = sorted([(Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file()])
        checkpoints = [
            c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]
        ]

        # log iteratively all new checkpoints
        for t, p, s in checkpoints:
            metadata = {
                'score': s,
                'original_filename': Path(p).name,
                'ModelCheckpoint': {
                    k: getattr(checkpoint_callback, k)
                    for k in [
                        'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', '_every_n_train_steps',
                        '_every_n_val_epochs'
                    ]
                    # ensure it does not break if `ModelCheckpoint` args change
                    if hasattr(checkpoint_callback, k)
                }
            } if _WANDB_GREATER_EQUAL_0_10_22 else None
            artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata)
            artifact.add_file(p, name='model.ckpt')
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
            self._logged_model_time[p] = t
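The `(path, mtime)` bookkeeping above is what prevents re-uploading a checkpoint whose filename is reused (e.g. `last.ckpt`) unless the file has actually been rewritten. Stripped of the W&B calls, the dedup logic amounts to something like this standalone sketch (function name and arguments are illustrative):

    from pathlib import Path
    from typing import Dict, List, Tuple


    def new_checkpoints(paths: List[str], logged: Dict[str, float]) -> List[Tuple[float, str]]:
        """Return (mtime, path) pairs not yet logged, or rewritten since they were last logged."""
        candidates = sorted((Path(p).stat().st_mtime, p) for p in paths if Path(p).is_file())
        return [(t, p) for t, p in candidates if p not in logged or logged[p] < t]


    # usage: after uploading, record the mtime so the same file is not uploaded twice
    # for t, p in new_checkpoints(["last.ckpt", "best.ckpt"], logged_model_time):
    #     ...upload p...
    #     logged_model_time[p] = t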
@@ -59,6 +59,7 @@ class CustomLogger(LightningLoggerBase):
        self.hparams_logged = None
        self.metrics_logged = {}
        self.finalized = False
        self.after_save_checkpoint_called = False

    @property
    def experiment(self):
@@ -92,6 +93,9 @@
    def version(self):
        return "1"

    def after_save_checkpoint(self, checkpoint_callback):
        self.after_save_checkpoint_called = True


def test_custom_logger(tmpdir):
@@ -115,6 +119,7 @@ def test_custom_logger(tmpdir):
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert logger.hparams_logged == model.hparams
    assert logger.metrics_logged != {}
    assert logger.after_save_checkpoint_called
    assert logger.finalized_status == "success"
@@ -24,14 +24,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


def get_warnings(recwarn):
    warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
    recwarn.clear()
    return warnings_text


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger_init(wandb, recwarn):
def test_wandb_logger_init(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""
@@ -51,8 +45,6 @@ def test_wandb_logger_init(wandb, recwarn):
    run = wandb.init()
    logger = WandbLogger(experiment=run)
    assert logger.experiment
    assert run.dir is not None
    assert logger.save_dir == run.dir

    # test wandb.init not called if there is a W&B run
    wandb.init().log.reset_mock()
@@ -140,10 +132,8 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):

    # mock return values of experiment
    wandb.run = None
    wandb.init().step = 0
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    logger.experiment.step = 0

    for _ in range(2):
        _ = logger.experiment
@@ -164,6 +154,71 @@
    assert trainer.log_dir == logger.save_dir


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_log_model(wandb, tmpdir):
    """ Test that the logger logs model checkpoints as W&B artifacts. """
    wandb.run = None
    model = BoringModel()

    # test log_model=True
    logger = WandbLogger(log_model=True)
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    wandb.init().log_artifact.assert_called_once()

    # test log_model='all'
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    logger = WandbLogger(log_model='all')
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    assert wandb.init().log_artifact.call_count == 2

    # test log_model=False
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    logger = WandbLogger(log_model=False)
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    assert not wandb.init().log_artifact.called

    # test correct metadata
    import pytorch_lightning.loggers.wandb as pl_wandb
    pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    wandb.Artifact.reset_mock()
    logger = pl_wandb.WandbLogger(log_model=True)
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    wandb.Artifact.assert_called_once_with(
        name='model-1',
        type='model',
        metadata={
            'score': None,
            'original_filename': 'epoch=1-step=5-v3.ckpt',
            'ModelCheckpoint': {
                'monitor': None,
                'mode': 'min',
                'save_last': None,
                'save_top_k': None,
                'save_weights_only': False,
                '_every_n_train_steps': 0,
                '_every_n_val_epochs': 1
            }
        }
    )


def test_wandb_sanitize_callable_params(tmpdir):
    """
    Callback functions are not serializable. Therefore, we give them a chance to return