Let metadata `score` be serializable by wand (#15544)

Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
geoffrey-g-delhomme 2022-11-05 15:51:49 +01:00 committed by GitHub
parent 12d6e44796
commit 7bdfced27c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 1 deletions

View File

@ -54,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normaliztion layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))
- Fixed an issue with `WandbLogger(log_model=True|'all)` raising an error and not being able to serialize tensors in the metadata ([#15544](https://github.com/Lightning-AI/lightning/pull/15544))
## [1.8.0] - 2022-11-01

View File

@ -22,6 +22,7 @@ from typing import Any, Dict, List, Mapping, Optional, Union
import torch.nn as nn
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
@ -572,7 +573,7 @@ class WandbLogger(Logger):
for t, p, s, tag in checkpoints:
metadata = (
{
"score": s,
"score": s.item() if isinstance(s, Tensor) else s,
"original_filename": Path(p).name,
checkpoint_callback.__class__.__name__: {
k: getattr(checkpoint_callback, k)

View File

@ -19,6 +19,7 @@ import pytest
import pytorch_lightning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -251,6 +252,40 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
)
@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
@mock.patch("pytorch_lightning.loggers.wandb.wandb")
def test_wandb_log_model_with_score(wandb, monkeypatch, tmpdir):
"""Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar
tensor."""
monkeypatch.setattr(pytorch_lightning.loggers.wandb, "_WANDB_GREATER_EQUAL_0_10_22", True)
wandb.run = None
model = BoringModel()
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
wandb.Artifact.reset_mock()
logger = WandbLogger(log_model=True)
logger.experiment.id = "1"
logger.experiment.name = "run_name"
checkpoint_callback = ModelCheckpoint(monitor="step")
trainer = Trainer(
default_root_dir=tmpdir,
logger=logger,
callbacks=[checkpoint_callback],
max_epochs=1,
limit_train_batches=3,
limit_val_batches=1,
)
trainer.fit(model)
calls = wandb.Artifact.call_args_list
assert len(calls) == 1
score = calls[0][1]["metadata"]["score"]
# model checkpoint monitors scalar tensors, but wandb can't serializable them - expect Python scalars in metadata
assert isinstance(score, int) and score == 3
@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
@mock.patch("pytorch_lightning.loggers.wandb.wandb")
def test_wandb_log_media(wandb, tmpdir):