lightning/pytorch_lightning/metrics/functional/ssim.py

153 lines
5.6 KiB
Python
Raw Normal View History

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Tuple
import torch
from pytorch_lightning.metrics.functional.reduction import reduce
from pytorch_lightning.metrics.utils import _check_same_shape
from torch.nn import functional as F
def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device):
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
def _gaussian_kernel(channel: int, kernel_size: Sequence[int], sigma: Sequence[float],
dtype: torch.dtype, device: torch.device):
gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
def _ssim_update(
preds: torch.Tensor,
target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if preds.dtype != target.dtype:
raise TypeError(
"Expected `preds` and `target` to have the same data type."
f" Got pred: {preds.dtype} and target: {target.dtype}."
)
_check_same_shape(preds, target)
if len(preds.shape) != 4:
raise ValueError(
"Expected `preds` and `target` to have BxCxHxW shape."
f" Got pred: {preds.shape} and target: {target.shape}."
)
return preds, target
def _ssim_compute(
preds: torch.Tensor,
target: torch.Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
):
if len(kernel_size) != 2 or len(sigma) != 2:
raise ValueError(
"Expected `kernel_size` and `sigma` to have the length of two."
f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
)
if any(x % 2 == 0 or x <= 0 for x in kernel_size):
raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")
if any(y <= 0 for y in sigma):
raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")
if data_range is None:
data_range = max(preds.max() - preds.min(), target.max() - target.min())
c1 = pow(k1 * data_range, 2)
c2 = pow(k2 * data_range, 2)
device = preds.device
channel = preds.size(1)
dtype = preds.dtype
kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device)
pad_w = (kernel_size[0] - 1) // 2
pad_h = (kernel_size[1] - 1) // 2
preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
outputs = F.conv2d(input_list, kernel, groups=channel)
output_list = [outputs[x * preds.size(0): (x + 1) * preds.size(0)] for x in range(len(outputs))]
mu_pred_sq = output_list[0].pow(2)
mu_target_sq = output_list[1].pow(2)
mu_pred_target = output_list[0] * output_list[1]
sigma_pred_sq = output_list[2] - mu_pred_sq
sigma_target_sq = output_list[3] - mu_target_sq
sigma_pred_target = output_list[4] - mu_pred_target
upper = 2 * sigma_pred_target + c2
lower = sigma_pred_sq + sigma_target_sq + c2
ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)
ssim_idx = ssim_idx[..., pad_h:-pad_h, pad_w:-pad_w]
return reduce(ssim_idx, reduction)
def ssim(
preds: torch.Tensor,
target: torch.Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
) -> torch.Tensor:
"""
Computes Structual Similarity Index Measure
Args:
pred: estimated image
target: ground truth image
kernel_size: size of the gaussian kernel (default: (11, 11))
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of SSIM. Default: 0.01
k2: Parameter of SSIM. Default: 0.03
Return:
Tensor with SSIM score
Example:
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> ssim(preds, target)
tensor(0.9219)
"""
preds, target = _ssim_update(preds, target)
return _ssim_compute(preds, target, kernel_size, sigma, reduction, data_range, k1, k2)