# lightning/tests/metrics/test_metrics.py
# (file-viewer metadata: 86 lines, 2.4 KiB, Python)

import numpy as np
import torch
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
class DummyTensorMetric(TensorMetric):
    """Stub ``TensorMetric`` used to exercise device/dtype handling.

    ``forward`` only verifies that both arguments reach it as
    ``torch.Tensor`` objects and then returns the constant ``1.0``.
    """

    def __init__(self):
        """Register the metric under the fixed name ``'dummy'``."""
        super().__init__('dummy')

    def forward(self, input1, input2):
        """Assert both inputs are torch tensors, then return ``1.0``."""
        # Checked in argument order so a failure points at the first bad input.
        for received in (input1, input2):
            assert isinstance(received, torch.Tensor)
        return 1.0
class DummyNumpyMetric(NumpyMetric):
    """Stub ``NumpyMetric`` used to exercise device/dtype handling.

    ``forward`` only verifies that both arguments reach it as
    ``np.ndarray`` objects and then returns the constant ``1.0``.
    """

    def __init__(self):
        """Register the metric under the fixed name ``'dummy'``."""
        super().__init__('dummy')

    def forward(self, input1, input2):
        """Assert both inputs are numpy arrays, then return ``1.0``."""
        # Checked in argument order so a failure points at the first bad input.
        for received in (input1, input2):
            assert isinstance(received, np.ndarray)
        return 1.0
def _test_metric(metric: Metric):
    """Check that device/dtype changes on ``metric`` are observable.

    After every conversion call (``to``, ``cuda``, ``cpu``, ``type``,
    ``float``, ``double``, ``half``) the helper asserts that both
    ``metric.device`` / ``metric.dtype`` and the value returned by calling
    the metric reflect the requested device or dtype. CUDA-specific steps
    run only when ``torch.cuda.is_available()`` is true.

    Args:
        metric: The metric instance to exercise; it is modified in place.
    """
    input1, input2 = torch.tensor([1.]), torch.tensor([2.])

    def move_and_verify(target_device, target_dtype):
        """Call ``metric.to`` and verify the metric and its output follow."""
        metric.to(device=target_device, dtype=target_dtype)

        result = metric(input1, input2)
        assert isinstance(result, torch.Tensor)

        if target_device is not None:
            # Accept the device as given (e.g. a string) or as a torch.device.
            accepted = [target_device, torch.device(target_device)]
            assert metric.device in accepted
            assert result.device in accepted

        if target_dtype is not None:
            assert metric.dtype == target_dtype
            assert result.dtype == target_dtype

    def expect_device(expected):
        """Assert the metric and a fresh result both sit on ``expected``."""
        assert metric.device == expected
        assert metric(input1, input2).device == expected

    def expect_dtype(expected):
        """Assert the metric and a fresh result both have dtype ``expected``."""
        assert metric.dtype == expected
        assert metric(input1, input2).dtype == expected

    # Sweep every (device, dtype) combination through ``metric.to``.
    devices = [None, 'cpu'] + (['cuda:0'] if torch.cuda.is_available() else [])
    combinations = [
        (dev, dt)
        for dev in devices
        for dt in (None, torch.float32, torch.float64)
    ]
    for dev, dt in combinations:
        move_and_verify(dev, dt)

    # Dedicated device shortcuts.
    if torch.cuda.is_available():
        metric.cuda(0)
        expect_device(torch.device('cuda', index=0))

    metric.cpu()
    expect_device(torch.device('cpu'))

    # Dedicated dtype shortcuts.
    metric.type(torch.int8)
    expect_dtype(torch.int8)

    metric.float()
    expect_dtype(torch.float32)

    metric.double()
    expect_dtype(torch.float64)

    # Half precision is exercised only after moving the metric to CUDA.
    if torch.cuda.is_available():
        metric.cuda()
        metric.half()
        expect_dtype(torch.float16)
def test_tensor_metric():
    """Run the shared device/dtype checks against ``DummyTensorMetric``."""
    tensor_metric = DummyTensorMetric()
    _test_metric(tensor_metric)
def test_numpy_metric():
    """Run the shared device/dtype checks against ``DummyNumpyMetric``."""
    numpy_metric = DummyNumpyMetric()
    _test_metric(numpy_metric)