diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4259684748..22c7070cfb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -100,6 +100,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))
+- Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231))
- MLFlowLogger now accepts `run_name` as an constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))
diff --git a/docs/source/common/loggers.rst b/docs/source/common/loggers.rst
index c6c5f0d865..5b1f13dbf4 100644
--- a/docs/source/common/loggers.rst
+++ b/docs/source/common/loggers.rst
@@ -202,7 +202,7 @@ The :class:`~pytorch_lightning.loggers.TestTubeLogger` is available anywhere exc
Weights and Biases
==================
-`Weights and Biases `_ is a third-party logger.
+`Weights and Biases `_ is a third-party logger.
To use :class:`~pytorch_lightning.loggers.WandbLogger` as your logger do the following.
First, install the package:
@@ -215,9 +215,14 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
.. code-block:: python
from pytorch_lightning.loggers import WandbLogger
- wandb_logger = WandbLogger(offline=True)
+
+ # instrument experiment with W&B
+ wandb_logger = WandbLogger(project='MNIST', log_model='all')
trainer = Trainer(logger=wandb_logger)
+ # log gradients and model topology
+ wandb_logger.watch(model)
+
The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your
:class:`~pytorch_lightning.core.lightning.LightningModule`.
@@ -226,8 +231,8 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except
class MyModule(LightningModule):
def any_lightning_module_function_or_hook(self):
some_img = fake_image()
- self.logger.experiment.log({
- "generated_images": [wandb.Image(some_img, caption="...")]
+ self.log({
+ "generated_images": [wandb.Image(some_img, caption="...")]
})
.. seealso::
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 7642ad95d0..067ebfdeaf 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -26,6 +26,7 @@ from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
+from weakref import proxy
import numpy as np
import torch
@@ -330,6 +331,10 @@ class ModelCheckpoint(Callback):
# Mode 3: save last checkpoints
self._save_last_checkpoint(trainer, monitor_candidates)
+ # notify loggers
+ if trainer.is_global_zero and trainer.logger:
+ trainer.logger.after_save_checkpoint(proxy(self))
+
def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool:
from pytorch_lightning.trainer.states import TrainerFn
return (
diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py
index 035a42338f..7736ed24ba 100644
--- a/pytorch_lightning/loggers/base.py
+++ b/pytorch_lightning/loggers/base.py
@@ -20,10 +20,12 @@ from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
+from weakref import ReferenceType
import numpy as np
import torch
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only
@@ -71,6 +73,15 @@ class LightningLoggerBase(ABC):
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
self._agg_default_func = agg_default_func
+ def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
+ """
+ Called after model checkpoint callback saves a new checkpoint
+
+ Args:
+ model_checkpoint: the model checkpoint callback instance
+ """
+ pass
+
def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
@@ -357,6 +368,10 @@ class LoggerCollection(LightningLoggerBase):
def __getitem__(self, index: int) -> LightningLoggerBase:
return [logger for logger in self._logger_iterable][index]
+ def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
+ for logger in self._logger_iterable:
+ logger.after_save_checkpoint(checkpoint_callback)
+
def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index 0f73153378..c127fa037e 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -15,20 +15,26 @@
Weights and Biases Logger
-------------------------
"""
+import operator
import os
from argparse import Namespace
+from pathlib import Path
from typing import Any, Dict, Optional, Union
+from weakref import ReferenceType
import torch.nn as nn
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.warnings import WarningCache
warning_cache = WarningCache()
_WANDB_AVAILABLE = _module_available("wandb")
+_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22")
try:
import wandb
@@ -40,7 +46,7 @@ except ImportError:
class WandbLogger(LightningLoggerBase):
r"""
- Log using `Weights and Biases `_.
+ Log using `Weights and Biases `_.
Install it with pip:
@@ -56,7 +62,15 @@ class WandbLogger(LightningLoggerBase):
version: Same as id.
anonymous: Enables or explicitly disables anonymous logging.
project: The name of the project to which this run will belong.
- log_model: Save checkpoints in wandb dir to upload on W&B servers.
+ log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
+ as W&B artifacts.
+
+ * if ``log_model == 'all'``, checkpoints are logged during training.
+ * if ``log_model == True``, checkpoints are logged at the end of training, except when
+ :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1``
+ which also logs every checkpoint during training.
+ * if ``log_model == False`` (default), no checkpoint is logged.
+
prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
@@ -71,15 +85,16 @@ class WandbLogger(LightningLoggerBase):
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
- wandb_logger = WandbLogger()
+
+ # instrument experiment with W&B
+ wandb_logger = WandbLogger(project='MNIST', log_model='all')
trainer = Trainer(logger=wandb_logger)
- Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
- make sure to use `commit=False` so the logging step does not increase.
+ # log gradients and model topology
+ wandb_logger.watch(model)
See Also:
- - `Tutorial `__
- on how to use W&B with PyTorch Lightning
+ - `Demo in Google Colab `__ with model logging
- `W&B Documentation `__
"""
@@ -114,6 +129,13 @@ class WandbLogger(LightningLoggerBase):
'Hint: Set `offline=False` to log your model.'
)
+ if log_model and not _WANDB_GREATER_EQUAL_0_10_22:
+ warning_cache.warn(
+ f'Providing log_model={log_model} requires wandb version >= 0.10.22'
+ ' for logging associated model metadata.\n'
+ 'Hint: Upgrade with `pip install --ugrade wandb`.'
+ )
+
if sync_step is not None:
warning_cache.warn(
"`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
@@ -125,6 +147,8 @@ class WandbLogger(LightningLoggerBase):
self._log_model = log_model
self._prefix = prefix
self._experiment = experiment
+ self._logged_model_time = {}
+ self._checkpoint_callback = None
# set wandb init arguments
anonymous_lut = {True: 'allow', False: None}
self._wandb_init = dict(
@@ -168,10 +192,6 @@ class WandbLogger(LightningLoggerBase):
os.environ['WANDB_MODE'] = 'dryrun'
self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run
- # save checkpoints in wandb dir to upload on W&B servers
- if self._save_dir is None:
- self._save_dir = self._experiment.dir
-
# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
@@ -213,8 +233,49 @@ class WandbLogger(LightningLoggerBase):
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id
+ def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
+ # log checkpoints as artifacts
+ if self._log_model == 'all' or self._log_model is True and checkpoint_callback.save_top_k == -1:
+ self._scan_and_log_checkpoints(checkpoint_callback)
+ elif self._log_model is True:
+ self._checkpoint_callback = checkpoint_callback
+
@rank_zero_only
def finalize(self, status: str) -> None:
- # upload all checkpoints from saving dir
- if self._log_model:
- wandb.save(os.path.join(self.save_dir, "*.ckpt"))
+ # log checkpoints as artifacts
+ if self._checkpoint_callback:
+ self._scan_and_log_checkpoints(self._checkpoint_callback)
+
+ def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None:
+ # get checkpoints to be saved with associated score
+ checkpoints = {
+ checkpoint_callback.last_model_path: checkpoint_callback.current_score,
+ checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
+ **checkpoint_callback.best_k_models
+ }
+ checkpoints = sorted([(Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file()])
+ checkpoints = [
+ c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]
+ ]
+
+ # log iteratively all new checkpoints
+ for t, p, s in checkpoints:
+ metadata = {
+ 'score': s,
+ 'original_filename': Path(p).name,
+ 'ModelCheckpoint': {
+ k: getattr(checkpoint_callback, k)
+ for k in [
+ 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', '_every_n_train_steps',
+ '_every_n_val_epochs'
+ ]
+ # ensure it does not break if `ModelCheckpoint` args change
+ if hasattr(checkpoint_callback, k)
+ }
+ } if _WANDB_GREATER_EQUAL_0_10_22 else None
+ artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata)
+ artifact.add_file(p, name='model.ckpt')
+ aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
+ self.experiment.log_artifact(artifact, aliases=aliases)
+ # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
+ self._logged_model_time[p] = t
diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py
index c20b609658..9209083148 100644
--- a/tests/loggers/test_base.py
+++ b/tests/loggers/test_base.py
@@ -59,6 +59,7 @@ class CustomLogger(LightningLoggerBase):
self.hparams_logged = None
self.metrics_logged = {}
self.finalized = False
+ self.after_save_checkpoint_called = False
@property
def experiment(self):
@@ -92,6 +93,9 @@ class CustomLogger(LightningLoggerBase):
def version(self):
return "1"
+ def after_save_checkpoint(self, checkpoint_callback):
+ self.after_save_checkpoint_called = True
+
def test_custom_logger(tmpdir):
@@ -115,6 +119,7 @@ def test_custom_logger(tmpdir):
assert trainer.state.finished, f"Training failed with {trainer.state}"
assert logger.hparams_logged == model.hparams
assert logger.metrics_logged != {}
+ assert logger.after_save_checkpoint_called
assert logger.finalized_status == "success"
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 22be315eaa..27185b911b 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -24,14 +24,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
-def get_warnings(recwarn):
- warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
- recwarn.clear()
- return warnings_text
-
-
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
-def test_wandb_logger_init(wandb, recwarn):
+def test_wandb_logger_init(wandb):
"""Verify that basic functionality of wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here."""
@@ -51,8 +45,6 @@ def test_wandb_logger_init(wandb, recwarn):
run = wandb.init()
logger = WandbLogger(experiment=run)
assert logger.experiment
- assert run.dir is not None
- assert logger.save_dir == run.dir
# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
@@ -140,10 +132,8 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
# mock return values of experiment
wandb.run = None
- wandb.init().step = 0
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
- logger.experiment.step = 0
for _ in range(2):
_ = logger.experiment
@@ -164,6 +154,71 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
assert trainer.log_dir == logger.save_dir
+@mock.patch('pytorch_lightning.loggers.wandb.wandb')
+def test_wandb_log_model(wandb, tmpdir):
+ """ Test that the logger creates the folders and files in the right place. """
+
+ wandb.run = None
+ model = BoringModel()
+
+ # test log_model=True
+ logger = WandbLogger(log_model=True)
+ logger.experiment.id = '1'
+ logger.experiment.project_name.return_value = 'project'
+ trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
+ trainer.fit(model)
+ wandb.init().log_artifact.assert_called_once()
+
+ # test log_model='all'
+ wandb.init().log_artifact.reset_mock()
+ wandb.init.reset_mock()
+ logger = WandbLogger(log_model='all')
+ logger.experiment.id = '1'
+ logger.experiment.project_name.return_value = 'project'
+ trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
+ trainer.fit(model)
+ assert wandb.init().log_artifact.call_count == 2
+
+ # test log_model=False
+ wandb.init().log_artifact.reset_mock()
+ wandb.init.reset_mock()
+ logger = WandbLogger(log_model=False)
+ logger.experiment.id = '1'
+ logger.experiment.project_name.return_value = 'project'
+ trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
+ trainer.fit(model)
+ assert not wandb.init().log_artifact.called
+
+ # test correct metadata
+ import pytorch_lightning.loggers.wandb as pl_wandb
+ pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True
+ wandb.init().log_artifact.reset_mock()
+ wandb.init.reset_mock()
+ wandb.Artifact.reset_mock()
+ logger = pl_wandb.WandbLogger(log_model=True)
+ logger.experiment.id = '1'
+ logger.experiment.project_name.return_value = 'project'
+ trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
+ trainer.fit(model)
+ wandb.Artifact.assert_called_once_with(
+ name='model-1',
+ type='model',
+ metadata={
+ 'score': None,
+ 'original_filename': 'epoch=1-step=5-v3.ckpt',
+ 'ModelCheckpoint': {
+ 'monitor': None,
+ 'mode': 'min',
+ 'save_last': None,
+ 'save_top_k': None,
+ 'save_weights_only': False,
+ '_every_n_train_steps': 0,
+ '_every_n_val_epochs': 1
+ }
+ }
+ )
+
+
def test_wandb_sanitize_callable_params(tmpdir):
"""
Callback function are not serializiable. Therefore, we get them a chance to return