diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 5bd34c2864..5c064c5296 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -54,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normaliztion layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063)) +- Fixed an issue with `WandbLogger(log_model=True|'all)` raising an error and not being able to serialize tensors in the metadata ([#15544](https://github.com/Lightning-AI/lightning/pull/15544)) + ## [1.8.0] - 2022-11-01 diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py index 5d60989c65..55757487d7 100644 --- a/src/pytorch_lightning/loggers/wandb.py +++ b/src/pytorch_lightning/loggers/wandb.py @@ -22,6 +22,7 @@ from typing import Any, Dict, List, Mapping, Optional, Union import torch.nn as nn from lightning_utilities.core.imports import RequirementCache +from torch import Tensor from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment @@ -572,7 +573,7 @@ class WandbLogger(Logger): for t, p, s, tag in checkpoints: metadata = ( { - "score": s, + "score": s.item() if isinstance(s, Tensor) else s, "original_filename": Path(p).name, checkpoint_callback.__class__.__name__: { k: getattr(checkpoint_callback, k) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index ec7661aadb..3ce1401efa 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -19,6 +19,7 @@ import pytest import pytorch_lightning from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -251,6 +252,40 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir): ) +@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock) +@mock.patch("pytorch_lightning.loggers.wandb.wandb") +def test_wandb_log_model_with_score(wandb, monkeypatch, tmpdir): + """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar + tensor.""" + monkeypatch.setattr(pytorch_lightning.loggers.wandb, "_WANDB_GREATER_EQUAL_0_10_22", True) + + wandb.run = None + model = BoringModel() + + wandb.init().log_artifact.reset_mock() + wandb.init.reset_mock() + wandb.Artifact.reset_mock() + logger = WandbLogger(log_model=True) + logger.experiment.id = "1" + logger.experiment.name = "run_name" + checkpoint_callback = ModelCheckpoint(monitor="step") + trainer = Trainer( + default_root_dir=tmpdir, + logger=logger, + callbacks=[checkpoint_callback], + max_epochs=1, + limit_train_batches=3, + limit_val_batches=1, + ) + trainer.fit(model) + + calls = wandb.Artifact.call_args_list + assert len(calls) == 1 + score = calls[0][1]["metadata"]["score"] + # model checkpoint monitors scalar tensors, but wandb can't serializable them - expect Python scalars in metadata + assert isinstance(score, int) and score == 3 + + @mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock) @mock.patch("pytorch_lightning.loggers.wandb.wandb") def test_wandb_log_media(wandb, tmpdir):