Let metadata `score` be serializable by wandb (#15544)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>

parent 12d6e44796
commit 7bdfced27c
@@ -54,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))
 
+- Fixed an issue with `WandbLogger(log_model=True|'all')` raising an error and not being able to serialize tensors in the metadata ([#15544](https://github.com/Lightning-AI/lightning/pull/15544))
+
 
 ## [1.8.0] - 2022-11-01
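Editor's note on the fix described in the entry above: wandb requires artifact metadata to be JSON-serializable, while `ModelCheckpoint` tracks its monitored score as a scalar `Tensor`. A minimal sketch (not part of this commit) of the failure mode and of the `.item()` conversion the diff below applies:

import json

import torch

score = torch.tensor(3)  # a scalar tensor, as produced by a monitored metric

# A raw tensor breaks JSON encoding of the metadata dict:
try:
    json.dumps({"score": score})
except TypeError as err:
    print(f"cannot serialize: {err}")

# Converting to a plain Python number first, as the change below does, avoids the error:
metadata = {"score": score.item() if isinstance(score, torch.Tensor) else score}
print(json.dumps(metadata))  # {"score": 3}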
@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Mapping, Optional, Union
 
 import torch.nn as nn
 from lightning_utilities.core.imports import RequirementCache
+from torch import Tensor
 
 from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
@@ -572,7 +573,7 @@ class WandbLogger(Logger):
         for t, p, s, tag in checkpoints:
             metadata = (
                 {
-                    "score": s,
+                    "score": s.item() if isinstance(s, Tensor) else s,
                     "original_filename": Path(p).name,
                     checkpoint_callback.__class__.__name__: {
                         k: getattr(checkpoint_callback, k)
@@ -19,6 +19,7 @@ import pytest
 
+import pytorch_lightning
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -251,6 +252,40 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     )
 
 
+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
+@mock.patch("pytorch_lightning.loggers.wandb.wandb")
+def test_wandb_log_model_with_score(wandb, monkeypatch, tmpdir):
+    """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar
+    tensor."""
+    monkeypatch.setattr(pytorch_lightning.loggers.wandb, "_WANDB_GREATER_EQUAL_0_10_22", True)
+
+    wandb.run = None
+    model = BoringModel()
+
+    wandb.init().log_artifact.reset_mock()
+    wandb.init.reset_mock()
+    wandb.Artifact.reset_mock()
+    logger = WandbLogger(log_model=True)
+    logger.experiment.id = "1"
+    logger.experiment.name = "run_name"
+    checkpoint_callback = ModelCheckpoint(monitor="step")
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        logger=logger,
+        callbacks=[checkpoint_callback],
+        max_epochs=1,
+        limit_train_batches=3,
+        limit_val_batches=1,
+    )
+    trainer.fit(model)
+
+    calls = wandb.Artifact.call_args_list
+    assert len(calls) == 1
+    score = calls[0][1]["metadata"]["score"]
+    # the model checkpoint monitors scalar tensors, but wandb can't serialize them - expect Python scalars in metadata
+    assert isinstance(score, int) and score == 3
+
+
 @mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_log_media(wandb, tmpdir):
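Editor's note on the indexing in the assertion above (`calls[0][1]["metadata"]["score"]`): each entry of a mock's `call_args_list` is an (args, kwargs) pair, so index `[1]` selects the keyword arguments. A small self-contained sketch (not from the commit; `wandb_stub` is a hypothetical stand-in for the patched module):

from unittest import mock

# Stand-in for the patched `wandb` module used in the tests above.
wandb_stub = mock.MagicMock()

# The logger under test would end up calling wandb.Artifact(...) with the sanitized metadata.
wandb_stub.Artifact("model-1", type="model", metadata={"score": 3})

# Each recorded call unpacks to (positional_args, keyword_args), which is what the test asserts on.
args, kwargs = wandb_stub.Artifact.call_args_list[0]
assert args == ("model-1",)
assert kwargs["metadata"]["score"] == 3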