103 lines
3.6 KiB
Python
103 lines
3.6 KiB
Python
from collections import namedtuple
|
|
from functools import partial
|
|
|
|
import pytest
|
|
import torch
|
|
from skimage.metrics import structural_similarity
|
|
|
|
from pytorch_lightning.metrics.functional import ssim
|
|
from pytorch_lightning.metrics.regression import SSIM
|
|
from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES
|
|
|
|
# Fix the RNG so the randomly generated images are identical on every run.
torch.manual_seed(42)

# One parametrized SSIM case: prediction images, target images, and the flag
# forwarded to the skimage reference as ``multichannel``.
Input = namedtuple('Input', ["preds", "target", "multichannel"])


def _make_input(size, channel, coef, multichannel, dtype):
    """Build one case: ``NUM_BATCHES`` random batches of ``size`` x ``size`` images.

    The target is the prediction multiplied by ``coef``.
    NOTE(review): because the target is an exact scaled copy, local SSIM is
    presumably close to constant across windows, which makes these cases
    insensitive to windowing differences — consider adding noise.
    """
    images = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
    return Input(preds=images, target=images * coef, multichannel=multichannel)


# Each row: (image side length, channels, target scale, multichannel, dtype).
# NOTE(review): torch.float aliases torch.float32 and torch.double aliases
# torch.float64, so only two distinct dtypes are actually exercised.
_inputs = [
    _make_input(*spec)
    for spec in (
        (12, 3, 0.9, True, torch.float),
        (13, 1, 0.8, False, torch.float32),
        (14, 1, 0.7, False, torch.double),
        (15, 3, 0.6, True, torch.float64),
    )
]
|
|
|
|
|
|
def _sk_metric(preds, target, data_range, multichannel):
    """Reference SSIM computed with scikit-image, used to check ``ssim``/``SSIM``.

    ``preds`` and ``target`` are reshaped to ``(-1, C, H, W)``, which flattens
    every leading dimension into a single image axis. They are then permuted
    to channels-last ``(N, H, W, C)`` numpy arrays for skimage.

    SSIM is evaluated independently for each 2-D image, and the per-image
    scores are averaged. Passing the whole stack to ``structural_similarity``
    in one call makes skimage treat the image axis as an extra spatial
    dimension: the Gaussian window then spans neighbouring images, and the
    first/last images along that axis are cropped away. The metric under test
    is presumably a per-image 2-D SSIM (TODO confirm against the
    pytorch_lightning implementation).

    Args:
        preds: predicted images whose last three dims are ``(C, H, W)``.
        target: ground-truth images, same shape as ``preds``.
        data_range: dynamic range of the images, forwarded to skimage.
        multichannel: if ``False``, only channel 0 is compared, as a 2-D
            array; if ``True``, all channels are passed and skimage treats
            the last axis as channels.

    Returns:
        The mean of the per-image SSIM scores.
    """
    c, h, w = preds.shape[-3:]
    sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
    sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
    if not multichannel:
        # Keep only the first channel; images become 2-D (H, W) arrays.
        sk_preds = sk_preds[:, :, :, 0]
        sk_target = sk_target[:, :, :, 0]

    # 11-tap Gaussian window (sigma 1.5) with population covariance.
    # NOTE(review): newer skimage releases deprecate ``multichannel`` in favour
    # of ``channel_axis``; kept here for the skimage version this file targets.
    scores = [
        structural_similarity(
            img_target, img_preds, data_range=data_range, multichannel=multichannel,
            gaussian_weights=True, win_size=11, sigma=1.5, use_sample_covariance=False,
        )
        for img_target, img_preds in zip(sk_target, sk_preds)
    ]
    return sum(scores) / len(scores)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "preds, target, multichannel",
    [(case.preds, case.target, case.multichannel) for case in _inputs],
)
class TestSSIM(MetricTester):
    """Compare the ``SSIM`` module and the ``ssim`` functional against skimage."""

    # Absolute tolerance for comparisons against the skimage reference.
    atol = 6e-5

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step):
        """Run the class-based metric with every ``ddp`` / ``dist_sync_on_step`` combination."""
        reference = partial(_sk_metric, data_range=1.0, multichannel=multichannel)
        self.run_class_metric_test(
            ddp, preds, target, SSIM, reference,
            metric_args={"data_range": 1.0},
            dist_sync_on_step=dist_sync_on_step,
        )

    def test_ssim_functional(self, preds, target, multichannel):
        """Run the functional ``ssim`` against the same skimage reference."""
        reference = partial(_sk_metric, data_range=1.0, multichannel=multichannel)
        self.run_functional_metric_test(
            preds, target, ssim, reference,
            metric_args={"data_range": 1.0},
        )
|
|
|
|
|
|
@pytest.mark.parametrize(
    ['pred', 'target', 'kernel', 'sigma'],
    [
        pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # len(shape)
        pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]),  # len(kernel), len(sigma)
        pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]),  # len(kernel), len(sigma)
        pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]),  # len(kernel), len(sigma)
        pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]),  # invalid kernel input (zero)
        pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]),  # invalid kernel input (even)
        pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]),  # invalid kernel input (negative)
        pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]),  # invalid sigma input (zero)
        # Valid kernel so that only the negative sigma can trigger the error
        # (previously [11, 0], which was already rejected as a bad kernel).
        pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, -1.5]),  # invalid sigma input (negative)
    ],
)
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
    """Check that ``ssim`` rejects bad inputs.

    Mismatched dtypes must raise ``TypeError``. Malformed shape, kernel, or
    sigma arguments must raise ``ValueError``.

    Args:
        pred: shape used to build the prediction tensor.
        target: shape used to build the target tensor.
        kernel: kernel size argument forwarded to ``ssim``.
        sigma: Gaussian sigma argument forwarded to ``ssim``.
    """
    # float32 prediction vs float64 target: a dtype mismatch.
    mismatched_preds = torch.rand(pred)
    mismatched_target = torch.rand(target, dtype=torch.float64)
    with pytest.raises(TypeError):
        ssim(mismatched_preds, mismatched_target)

    # Matching dtypes, so any error must come from shape/kernel/sigma validation.
    preds_t = torch.rand(pred)
    target_t = torch.rand(target)
    with pytest.raises(ValueError):
        ssim(preds_t, target_t, kernel, sigma)
|