lightning/tests/metrics/test_composition.py

505 lines
14 KiB
Python

from distutils.version import LooseVersion
from operator import neg, pos
import pytest
import torch
from pytorch_lightning.metrics.compositional import CompositionalMetric
from pytorch_lightning.metrics.metric import Metric
_MARK_TORCH_LOWER_1_4 = dict(
condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"), reason='required PT >= 1.5'
)
_MARK_TORCH_LOWER_1_5 = dict(
condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason='required PT >= 1.6'
)
class DummyMetric(Metric):
def __init__(self, val_to_return):
super().__init__()
self._num_updates = 0
self._val_to_return = val_to_return
def update(self, *args, **kwargs) -> None:
self._num_updates += 1
def compute(self):
return torch.tensor(self._val_to_return)
def reset(self):
self._num_updates = 0
return super().reset()
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(4)),
(2, torch.tensor(4)),
(2.0, torch.tensor(4.0)),
(torch.tensor(2), torch.tensor(4)),
],
)
def test_metrics_add(second_operand, expected_result):
first_metric = DummyMetric(2)
final_add = first_metric + second_operand
final_radd = second_operand + first_metric
assert isinstance(final_add, CompositionalMetric)
assert isinstance(final_radd, CompositionalMetric)
assert torch.allclose(expected_result, final_add.compute())
assert torch.allclose(expected_result, final_radd.compute())
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_and(second_operand, expected_result):
first_metric = DummyMetric(2)
final_and = first_metric & second_operand
final_rand = second_operand & first_metric
assert isinstance(final_and, CompositionalMetric)
assert isinstance(final_rand, CompositionalMetric)
assert torch.allclose(expected_result, final_and.compute())
assert torch.allclose(expected_result, final_rand.compute())
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(True)),
(2, torch.tensor(True)),
(2.0, torch.tensor(True)),
(torch.tensor(2), torch.tensor(True)),
],
)
def test_metrics_eq(second_operand, expected_result):
first_metric = DummyMetric(2)
final_eq = first_metric == second_operand
assert isinstance(final_eq, CompositionalMetric)
# can't use allclose for bool tensors
assert (expected_result == final_eq.compute()).all()
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(2)),
(2, torch.tensor(2)),
(2.0, torch.tensor(2.0)),
(torch.tensor(2), torch.tensor(2)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_floordiv(second_operand, expected_result):
first_metric = DummyMetric(5)
final_floordiv = first_metric // second_operand
assert isinstance(final_floordiv, CompositionalMetric)
assert torch.allclose(expected_result, final_floordiv.compute())
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(True)),
(2, torch.tensor(True)),
(2.0, torch.tensor(True)),
(torch.tensor(2), torch.tensor(True)),
],
)
def test_metrics_ge(second_operand, expected_result):
first_metric = DummyMetric(5)
final_ge = first_metric >= second_operand
assert isinstance(final_ge, CompositionalMetric)
# can't use allclose for bool tensors
assert (expected_result == final_ge.compute()).all()
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(True)),
(2, torch.tensor(True)),
(2.0, torch.tensor(True)),
(torch.tensor(2), torch.tensor(True)),
],
)
def test_metrics_gt(second_operand, expected_result):
first_metric = DummyMetric(5)
final_gt = first_metric > second_operand
assert isinstance(final_gt, CompositionalMetric)
# can't use allclose for bool tensors
assert (expected_result == final_gt.compute()).all()
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(False)),
(2, torch.tensor(False)),
(2.0, torch.tensor(False)),
(torch.tensor(2), torch.tensor(False)),
],
)
def test_metrics_le(second_operand, expected_result):
first_metric = DummyMetric(5)
final_le = first_metric <= second_operand
assert isinstance(final_le, CompositionalMetric)
# can't use allclose for bool tensors
assert (expected_result == final_le.compute()).all()
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(False)),
(2, torch.tensor(False)),
(2.0, torch.tensor(False)),
(torch.tensor(2), torch.tensor(False)),
],
)
def test_metrics_lt(second_operand, expected_result):
first_metric = DummyMetric(5)
final_lt = first_metric < second_operand
assert isinstance(final_lt, CompositionalMetric)
# can't use allclose for bool tensors
assert (expected_result == final_lt.compute()).all()
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[(DummyMetric([2, 2, 2]), torch.tensor(12)), (torch.tensor([2, 2, 2]), torch.tensor(12))],
)
def test_metrics_matmul(second_operand, expected_result):
first_metric = DummyMetric([2, 2, 2])
final_matmul = first_metric @ second_operand
assert isinstance(final_matmul, CompositionalMetric)
assert torch.allclose(expected_result, final_matmul.compute())
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(1)),
(2, torch.tensor(1)),
(2.0, torch.tensor(1)),
(torch.tensor(2), torch.tensor(1)),
],
)
def test_metrics_mod(second_operand, expected_result):
first_metric = DummyMetric(5)
final_mod = first_metric % second_operand
assert isinstance(final_mod, CompositionalMetric)
# prevent Runtime error for PT 1.8 - Long did not match Float
assert torch.allclose(expected_result.to(float), final_mod.compute().to(float))
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(4)),
(2, torch.tensor(4)),
(2.0, torch.tensor(4.0)),
(torch.tensor(2), torch.tensor(4)),
],
)
def test_metrics_mul(second_operand, expected_result):
first_metric = DummyMetric(2)
final_mul = first_metric * second_operand
final_rmul = second_operand * first_metric
assert isinstance(final_mul, CompositionalMetric)
assert isinstance(final_rmul, CompositionalMetric)
assert torch.allclose(expected_result, final_mul.compute())
assert torch.allclose(expected_result, final_rmul.compute())
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(False)),
(2, torch.tensor(False)),
(2.0, torch.tensor(False)),
(torch.tensor(2), torch.tensor(False)),
],
)
def test_metrics_ne(second_operand, expected_result):
first_metric = DummyMetric(2)
final_ne = first_metric != second_operand
assert isinstance(final_ne, CompositionalMetric)
# can't use allclose for bool tensors
assert (expected_result == final_ne.compute()).all()
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_or(second_operand, expected_result):
first_metric = DummyMetric([-1, -2, 3])
final_or = first_metric | second_operand
final_ror = second_operand | first_metric
assert isinstance(final_or, CompositionalMetric)
assert isinstance(final_ror, CompositionalMetric)
assert torch.allclose(expected_result, final_or.compute())
assert torch.allclose(expected_result, final_ror.compute())
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
pytest.param(DummyMetric(2), torch.tensor(4)),
pytest.param(2, torch.tensor(4)),
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
pytest.param(torch.tensor(2), torch.tensor(4)),
],
)
def test_metrics_pow(second_operand, expected_result):
first_metric = DummyMetric(2)
final_pow = first_metric**second_operand
assert isinstance(final_pow, CompositionalMetric)
assert torch.allclose(expected_result, final_pow.compute())
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_rfloordiv(first_operand, expected_result):
second_operand = DummyMetric(2)
final_rfloordiv = first_operand // second_operand
assert isinstance(final_rfloordiv, CompositionalMetric)
assert torch.allclose(expected_result, final_rfloordiv.compute())
@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor([2, 2, 2]), torch.tensor(12))])
def test_metrics_rmatmul(first_operand, expected_result):
second_operand = DummyMetric([2, 2, 2])
final_rmatmul = first_operand @ second_operand
assert isinstance(final_rmatmul, CompositionalMetric)
assert torch.allclose(expected_result, final_rmatmul.compute())
@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor(2), torch.tensor(2))])
def test_metrics_rmod(first_operand, expected_result):
second_operand = DummyMetric(5)
final_rmod = first_operand % second_operand
assert isinstance(final_rmod, CompositionalMetric)
assert torch.allclose(expected_result, final_rmod.compute())
@pytest.mark.parametrize(
"first_operand,expected_result",
[
pytest.param(DummyMetric(2), torch.tensor(4)),
pytest.param(2, torch.tensor(4)),
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
],
)
def test_metrics_rpow(first_operand, expected_result):
second_operand = DummyMetric(2)
final_rpow = first_operand**second_operand
assert isinstance(final_rpow, CompositionalMetric)
assert torch.allclose(expected_result, final_rpow.compute())
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[
(DummyMetric(3), torch.tensor(1)),
(3, torch.tensor(1)),
(3.0, torch.tensor(1.0)),
(torch.tensor(3), torch.tensor(1)),
],
)
def test_metrics_rsub(first_operand, expected_result):
second_operand = DummyMetric(2)
final_rsub = first_operand - second_operand
assert isinstance(final_rsub, CompositionalMetric)
assert torch.allclose(expected_result, final_rsub.compute())
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[
(DummyMetric(6), torch.tensor(2.0)),
(6, torch.tensor(2.0)),
(6.0, torch.tensor(2.0)),
(torch.tensor(6), torch.tensor(2.0)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_rtruediv(first_operand, expected_result):
second_operand = DummyMetric(3)
final_rtruediv = first_operand / second_operand
assert isinstance(final_rtruediv, CompositionalMetric)
assert torch.allclose(expected_result, final_rtruediv.compute())
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(1)),
(2, torch.tensor(1)),
(2.0, torch.tensor(1.0)),
(torch.tensor(2), torch.tensor(1)),
],
)
def test_metrics_sub(second_operand, expected_result):
first_metric = DummyMetric(3)
final_sub = first_metric - second_operand
assert isinstance(final_sub, CompositionalMetric)
assert torch.allclose(expected_result, final_sub.compute())
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(3), torch.tensor(2.0)),
(3, torch.tensor(2.0)),
(3.0, torch.tensor(2.0)),
(torch.tensor(3), torch.tensor(2.0)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_truediv(second_operand, expected_result):
first_metric = DummyMetric(6)
final_truediv = first_metric / second_operand
assert isinstance(final_truediv, CompositionalMetric)
assert torch.allclose(expected_result, final_truediv.compute())
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[(DummyMetric([1, 0, 3]), torch.tensor([-2, -2, 0])), (torch.tensor([1, 0, 3]), torch.tensor([-2, -2, 0]))],
)
def test_metrics_xor(second_operand, expected_result):
first_metric = DummyMetric([-1, -2, 3])
final_xor = first_metric ^ second_operand
final_rxor = second_operand ^ first_metric
assert isinstance(final_xor, CompositionalMetric)
assert isinstance(final_rxor, CompositionalMetric)
assert torch.allclose(expected_result, final_xor.compute())
assert torch.allclose(expected_result, final_rxor.compute())
def test_metrics_abs():
first_metric = DummyMetric(-1)
final_abs = abs(first_metric)
assert isinstance(final_abs, CompositionalMetric)
assert torch.allclose(torch.tensor(1), final_abs.compute())
def test_metrics_invert():
first_metric = DummyMetric(1)
final_inverse = ~first_metric
assert isinstance(final_inverse, CompositionalMetric)
assert torch.allclose(torch.tensor(-2), final_inverse.compute())
def test_metrics_neg():
first_metric = DummyMetric(1)
final_neg = neg(first_metric)
assert isinstance(final_neg, CompositionalMetric)
assert torch.allclose(torch.tensor(-1), final_neg.compute())
def test_metrics_pos():
first_metric = DummyMetric(-1)
final_pos = pos(first_metric)
assert isinstance(final_pos, CompositionalMetric)
assert torch.allclose(torch.tensor(1), final_pos.compute())
def test_compositional_metrics_update():
compos = DummyMetric(5) + DummyMetric(4)
assert isinstance(compos, CompositionalMetric)
compos.update()
compos.update()
compos.update()
assert isinstance(compos.metric_a, DummyMetric)
assert isinstance(compos.metric_b, DummyMetric)
assert compos.metric_a._num_updates == 3
assert compos.metric_b._num_updates == 3