324 lines
9.2 KiB
Python
324 lines
9.2 KiB
Python
import os
|
|
from typing import Any
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
import tests.base.develop_utils as tutils
|
|
from tests.base import EvalModelTemplate
|
|
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
|
|
from pytorch_lightning import Trainer
|
|
|
|
|
|
class DummyTensorMetric(TensorMetric):
    """Test-only TensorMetric named ``"dummy"`` that returns a constant.

    ``forward`` asserts that both of its arguments are ``torch.Tensor``
    instances and always returns ``torch.tensor([1.0])``.
    """

    def __init__(self):
        super().__init__("dummy")

    def forward(self, input1, input2):
        # Check the arguments in order: first input1, then input2.
        for arg in (input1, input2):
            assert isinstance(arg, torch.Tensor)
        return torch.tensor([1.0])
|
|
|
|
|
|
class DummyNumpyMetric(NumpyMetric):
    """Test-only NumpyMetric named ``"dummy"`` that returns a constant.

    ``forward`` asserts that both of its arguments are ``np.ndarray``
    instances and always returns the plain float ``1.0``.
    """

    def __init__(self):
        super().__init__("dummy")

    def forward(self, input1, input2):
        # Check the arguments in order: first input1, then input2.
        for arg in (input1, input2):
            assert isinstance(arg, np.ndarray)
        return 1.0
|
|
|
|
|
|
class DummyTensorCollectionMetric(TensorMetric):
    """Test-only TensorMetric named ``"dummy"`` that returns a tuple of values.

    ``forward`` asserts that both of its arguments are ``torch.Tensor``
    instances and returns the 4-tuple ``(1.0, 2.0, 3.0, 4.0)``.
    """

    def __init__(self):
        super().__init__("dummy")

    def forward(self, input1, input2):
        # Check the arguments in order: first input1, then input2.
        for arg in (input1, input2):
            assert isinstance(arg, torch.Tensor)
        return 1.0, 2.0, 3.0, 4.0
|
|
|
|
|
|
@pytest.mark.parametrize("metric", [DummyTensorCollectionMetric()])
def test_collection_metric(metric: Metric):
    """Check ``metric.device`` / ``metric.dtype`` bookkeeping for a metric that returns a collection.

    The metric is moved through every (device, dtype) combination via ``.to``.
    Then the explicit ``cuda``/``cpu``/``type``/``float``/``double``/``half``
    helpers are exercised, with CUDA steps only when a GPU is available.
    """
    input1 = torch.tensor([1.0])
    input2 = torch.tensor([2.0])

    def _move_and_verify(device, dtype):
        metric.to(device=device, dtype=dtype)

        # The collection metric's output must not be collapsed into one tensor.
        assert not isinstance(metric(input1, input2), torch.Tensor)

        if device is not None:
            # The metric may report either the raw spec string or a torch.device.
            assert metric.device in [device, torch.device(device)]
        if dtype is not None:
            assert metric.dtype == dtype

    has_cuda = torch.cuda.is_available()
    candidate_devices = [None, "cpu"] + (["cuda:0"] if has_cuda else [])
    candidate_dtypes = [None, torch.float32, torch.float64]

    for dev in candidate_devices:
        for dt in candidate_dtypes:
            _move_and_verify(device=dev, dtype=dt)

    if has_cuda:
        metric.cuda(0)
        assert metric.device == torch.device("cuda", index=0)

    metric.cpu()
    assert metric.device == torch.device("cpu")

    metric.type(torch.int8)
    assert metric.dtype == torch.int8

    metric.float()
    assert metric.dtype == torch.float32

    metric.double()
    assert metric.dtype == torch.float64
    # Every element of the returned collection should follow the metric's dtype.
    outputs = metric(input1, input2)
    assert all(out.dtype == torch.float64 for out in outputs)

    if has_cuda:
        metric.cuda()
        metric.half()
        assert metric.dtype == torch.float16
|
|
|
|
|
|
@pytest.mark.parametrize(
    "metric",
    [
        DummyTensorMetric(),
        DummyNumpyMetric(),
    ],
)
def test_metric(metric: Metric):
    """Check ``metric.device`` / ``metric.dtype`` bookkeeping for a single-valued metric.

    For each setting, both the metric's reported device/dtype and the device/dtype
    of the tensor it returns are verified.
    """
    input1 = torch.tensor([1.0])
    input2 = torch.tensor([2.0])

    def _move_and_verify(device, dtype):
        metric.to(device=device, dtype=dtype)

        value = metric(input1, input2)
        assert isinstance(value, torch.Tensor)

        if device is not None:
            # Accept either the raw spec string or the equivalent torch.device.
            accepted = [device, torch.device(device)]
            assert metric.device in accepted
            assert value.device in accepted

        if dtype is not None:
            assert metric.dtype == dtype
            assert value.dtype == dtype

    has_cuda = torch.cuda.is_available()
    candidate_devices = [None, "cpu"] + (["cuda:0"] if has_cuda else [])
    candidate_dtypes = [None, torch.float32, torch.float64]

    for dev in candidate_devices:
        for dt in candidate_dtypes:
            _move_and_verify(device=dev, dtype=dt)

    if has_cuda:
        metric.cuda(0)
        first_gpu = torch.device("cuda", index=0)
        assert metric.device == first_gpu
        assert metric(input1, input2).device == first_gpu

    metric.cpu()
    cpu_device = torch.device("cpu")
    assert metric.device == cpu_device
    assert metric(input1, input2).device == cpu_device

    # Each dtype-cast helper must update both the metric and its output.
    for cast, expected_dtype in ((metric.float, torch.float32), (metric.double, torch.float64)):
        cast()
        assert metric.dtype == expected_dtype
        assert metric(input1, input2).dtype == expected_dtype

    if has_cuda:
        metric.cuda()
        metric.half()
        assert metric.dtype == torch.float16
        assert metric(input1, input2).dtype == torch.float16
|
|
|
|
|
|
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("metric", [DummyTensorMetric, DummyNumpyMetric])
def test_model_pickable(tmpdir, metric: type):
    """Make sure that metrics are pickable by including into a model and running in multi-gpu mode.

    Args:
        tmpdir: pytest temporary directory used as the trainer's root dir.
        metric: a metric *class* (not an instance); a fresh instance is attached
            to the model as ``model.metric``.

    NOTE(review): the pickling requirement presumably comes from ``ddp_spawn``
    sending the model to spawned worker processes. Confirm against the Trainer.
    """
    tutils.set_random_master_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=10,
        gpus=[0, 1],
        distributed_backend="ddp_spawn",
    )

    model = EvalModelTemplate()
    model.metric = metric()
    model.training_step = model.training_step__using_metrics

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # Only completion is checked here: a return value of 1 is treated as a
    # successful ddp run. Model accuracy is not asserted.
    assert result == 1, "ddp model failed to complete"
|
|
|
|
|
|
@pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()])
def test_saving_pickable(tmpdir, metric: Metric):
    """Make sure that metrics are pickable by saving and loading them using torch.

    The metric is evaluated on the same random inputs before saving and after
    loading, and the two results must be identical.
    """
    x, y = torch.randn(10), torch.randn(10)
    results_before_save = metric(x, y)

    # save metric
    save_path = os.path.join(tmpdir, "save_test.ckpt")
    torch.save(metric, save_path)

    # load metric
    # NOTE(review): newer torch releases default ``torch.load`` to
    # ``weights_only=True``, which would reject a fully pickled module. Confirm
    # against the torch version this suite pins.
    new_metric = torch.load(save_path)
    results_after_load = new_metric(x, y)

    # Check metric value is the same. ``torch.equal`` compares shape and every
    # element explicitly, instead of relying on the truthiness of an
    # element-wise ``==``. That would raise for multi-element results and
    # would let broadcasting hide a shape mismatch.
    assert torch.equal(results_before_save, results_after_load)
|
|
|
|
|
|
def test_correct_call_order():
    """Check that the metric hooks are invoked in the expected order.

    ``DummyMetric`` overrides each hook so that it appends its own name to
    ``call_history`` before delegating to the ``Metric`` base implementation.
    The recorded history is then compared against the expected sequence.
    """

    class DummyMetric(Metric):
        def __init__(self):
            super().__init__("dummy")
            self.call_history = ["init"]

        # NOTE(review): these hooks are staticmethods that take the module as an
        # explicit ``self`` argument. The base versions are called the same way
        # (``super(DummyMetric, self).<hook>(self, ...)``). Presumably the base
        # class registers them as forward (pre-)hooks; confirm against Metric.
        @staticmethod
        def input_convert(self, data: Any):
            self.call_history.append("input_convert")
            return super(DummyMetric, self).input_convert(self, data)

        def forward(self, tensor1, tensor2):
            self.call_history.append("forward")
            return tensor1 - tensor2

        @staticmethod
        def output_convert(self, data: Any, output: Any):
            self.call_history.append("output_convert")
            return super(DummyMetric, self).output_convert(self, data, output)

        def ddp_sync(self, tensor: Any):
            self.call_history.append("ddp_sync")
            return super().ddp_sync(tensor)

        @staticmethod
        def ddp_reduce(self, data: Any, output: Any):
            self.call_history.append("ddp_reduce")
            return super(DummyMetric, self).ddp_reduce(self, data, output)

        def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
            self.call_history.append("aggregate")
            return super().aggregate(*tensors)

        def reset(self):
            self.call_history.append("reset")
            return super().reset()

        @property
        def aggregated(self) -> torch.Tensor:
            self.call_history.append("aggregated")
            return super().aggregated

        @staticmethod
        def compute(self, data: Any, output: Any):
            self.call_history.append("compute")
            return super(DummyMetric, self).compute(self, data, output)

    # Hooks expected to be recorded by one forward call ...
    per_call = [
        "input_convert",
        "forward",
        "output_convert",
        "ddp_reduce",
        "ddp_sync",
        "aggregate",
        "compute",
    ]
    # ... and by one read of the ``aggregated`` property.
    per_aggregated = ["aggregated", "aggregate", "reset", "compute"]

    metric = DummyMetric()
    assert metric.call_history == ["init"]

    result = metric(torch.tensor([2.0]), torch.tensor([1.0]))
    assert torch.allclose(result, torch.tensor(1.0))
    expected = ["init"] + per_call
    assert metric.call_history == expected

    aggr = metric.aggregated
    expected += per_aggregated
    assert metric.call_history == expected
    assert torch.allclose(aggr, result)

    _ = metric(torch.tensor(2.0), torch.tensor(1.0))
    expected += per_call
    assert metric.call_history == expected

    # Fresh metric: the two forward results (2-1 and 3-0) are expected to
    # aggregate to 4.0, with a single ``aggregated`` read at the end.
    metric = DummyMetric()
    _ = metric(torch.tensor([2.0]), torch.tensor([1.0]))
    _ = metric(torch.tensor([3.0]), torch.tensor([0.0]))

    aggregated = metric.aggregated

    assert torch.allclose(aggregated, torch.tensor(4.0))

    assert metric.call_history == ["init"] + per_call * 2 + per_aggregated
|