lightning/pytorch_lightning/metrics/functional/psnr.py

66 lines
2.0 KiB
Python

from typing import Optional, Tuple
import torch
from pytorch_lightning.utilities import rank_zero_warn
def _psnr_compute(
sum_squared_error: torch.Tensor,
n_obs: int,
data_range: float,
base: float = 10.0,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
if reduction != 'elementwise_mean':
rank_zero_warn(f'The `reduction={reduction}` parameter is unused and will not have any effect.')
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
return psnr
def _psnr_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
n_obs = target.numel()
return sum_squared_error, n_obs
def psnr(
preds: torch.Tensor,
target: torch.Tensor,
data_range: Optional[float] = None,
base: float = 10.0,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
"""
Computes the peak signal-to-noise ratio
Args:
preds: estimated signal
target: groun truth signal
data_range: the range of the data. If None, it is determined from the data (max - min)
base: a base of a logarithm to use (default: 10)
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
Return:
Tensor with PSNR score
Example:
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> psnr(pred, target)
tensor(2.5527)
"""
if data_range is None:
data_range = target.max() - target.min()
else:
data_range = torch.tensor(float(data_range))
sum_squared_error, n_obs = _psnr_update(preds, target)
return _psnr_compute(sum_squared_error, n_obs, data_range, base, reduction)