312 lines
13 KiB
Python
312 lines
13 KiB
Python
import pytest
|
|
import torch
|
|
from torch import rand, randint
|
|
|
|
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
|
|
from pytorch_lightning.metrics.utils import select_topk, to_onehot
|
|
from tests.metrics.classification.inputs import _input_binary as _bin
|
|
from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob
|
|
from tests.metrics.classification.inputs import _input_multiclass as _mc
|
|
from tests.metrics.classification.inputs import _input_multiclass_prob as _mc_prob
|
|
from tests.metrics.classification.inputs import _input_multidim_multiclass as _mdmc
|
|
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _mdmc_prob
|
|
from tests.metrics.classification.inputs import _input_multilabel as _ml
|
|
from tests.metrics.classification.inputs import _input_multilabel_multidim as _mlmd
|
|
from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _mlmd_prob
|
|
from tests.metrics.classification.inputs import _input_multilabel_prob as _ml_prob
|
|
from tests.metrics.classification.inputs import Input
|
|
from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, THRESHOLD
|
|
|
|
torch.manual_seed(42)
|
|
|
|
# Some additional inputs to test on
|
|
_ml_prob_half = Input(_ml_prob.preds.half(), _ml_prob.target)
|
|
|
|
_mc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2)
|
|
_mc_prob_2cls_preds /= _mc_prob_2cls_preds.sum(dim=2, keepdim=True)
|
|
_mc_prob_2cls = Input(_mc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)))
|
|
|
|
_mdmc_prob_many_dims_preds = rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM)
|
|
_mdmc_prob_many_dims_preds /= _mdmc_prob_many_dims_preds.sum(dim=2, keepdim=True)
|
|
_mdmc_prob_many_dims = Input(
|
|
_mdmc_prob_many_dims_preds,
|
|
randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)),
|
|
)
|
|
|
|
_mdmc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM)
|
|
_mdmc_prob_2cls_preds /= _mdmc_prob_2cls_preds.sum(dim=2, keepdim=True)
|
|
_mdmc_prob_2cls = Input(_mdmc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)))
|
|
|
|
# Some utils
|
|
T = torch.Tensor
|
|
|
|
|
|
def _idn(x):
|
|
return x
|
|
|
|
|
|
def _usq(x):
|
|
return x.unsqueeze(-1)
|
|
|
|
|
|
def _thrs(x):
|
|
return x >= THRESHOLD
|
|
|
|
|
|
def _rshp1(x):
|
|
return x.reshape(x.shape[0], -1)
|
|
|
|
|
|
def _rshp2(x):
|
|
return x.reshape(x.shape[0], x.shape[1], -1)
|
|
|
|
|
|
def _onehot(x):
|
|
return to_onehot(x, NUM_CLASSES)
|
|
|
|
|
|
def _onehot2(x):
|
|
return to_onehot(x, 2)
|
|
|
|
|
|
def _top1(x):
|
|
return select_topk(x, 1)
|
|
|
|
|
|
def _top2(x):
|
|
return select_topk(x, 2)
|
|
|
|
|
|
# To avoid ugly black line wrapping
|
|
def _ml_preds_tr(x):
|
|
return _rshp1(_thrs(x))
|
|
|
|
|
|
def _onehot_rshp1(x):
|
|
return _onehot(_rshp1(x))
|
|
|
|
|
|
def _onehot2_rshp1(x):
|
|
return _onehot2(_rshp1(x))
|
|
|
|
|
|
def _top1_rshp2(x):
|
|
return _top1(_rshp2(x))
|
|
|
|
|
|
def _top2_rshp2(x):
|
|
return _top2(_rshp2(x))
|
|
|
|
|
|
def _probs_to_mc_preds_tr(x):
|
|
return _onehot2(_thrs(x))
|
|
|
|
|
|
def _mlmd_prob_to_mc_preds_tr(x):
|
|
return _onehot2(_rshp1(_thrs(x)))
|
|
|
|
|
|
########################
|
|
# Test correct inputs
|
|
########################
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target",
|
|
[
|
|
#############################
|
|
# Test usual expected cases
|
|
(_bin, None, False, None, "multi-class", _usq, _usq),
|
|
(_bin, 1, False, None, "multi-class", _usq, _usq),
|
|
(_bin_prob, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq),
|
|
(_ml_prob, None, None, None, "multi-label", _thrs, _idn),
|
|
(_ml, None, False, None, "multi-dim multi-class", _idn, _idn),
|
|
(_ml_prob, None, None, None, "multi-label", _ml_preds_tr, _rshp1),
|
|
(_ml_prob, None, None, 2, "multi-label", _top2, _rshp1),
|
|
(_mlmd, None, False, None, "multi-dim multi-class", _rshp1, _rshp1),
|
|
(_mc, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot),
|
|
(_mc_prob, None, None, None, "multi-class", _top1, _onehot),
|
|
(_mc_prob, None, None, 2, "multi-class", _top2, _onehot),
|
|
(_mdmc, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot),
|
|
(_mdmc_prob, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot),
|
|
(_mdmc_prob, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot),
|
|
(_mdmc_prob_many_dims, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1),
|
|
(_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1),
|
|
###########################
|
|
# Test some special cases
|
|
# Make sure that half precision works, i.e. is converted to full precision
|
|
(_ml_prob_half, None, None, None, "multi-label", lambda x: _ml_preds_tr(x.float()), _rshp1),
|
|
# Binary as multiclass
|
|
(_bin, None, None, None, "multi-class", _onehot2, _onehot2),
|
|
# Binary probs as multiclass
|
|
(_bin_prob, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2),
|
|
# Multilabel as multiclass
|
|
(_ml, None, True, None, "multi-dim multi-class", _onehot2, _onehot2),
|
|
# Multilabel probs as multiclass
|
|
(_ml_prob, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2),
|
|
# Multidim multilabel as multiclass
|
|
(_mlmd, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1),
|
|
# Multidim multilabel probs as multiclass
|
|
(_mlmd_prob, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1),
|
|
# Multiclass prob with 2 classes as binary
|
|
(_mc_prob_2cls, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq),
|
|
# Multi-dim multi-class with 2 classes as multi-label
|
|
(_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn),
|
|
],
|
|
)
|
|
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
|
|
|
|
def __get_data_type_enum(str_exp_mode):
|
|
return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)
|
|
|
|
for exp_mode in (exp_mode, __get_data_type_enum(exp_mode)):
|
|
preds_out, target_out, mode = _input_format_classification(
|
|
preds=inputs.preds[0],
|
|
target=inputs.target[0],
|
|
threshold=THRESHOLD,
|
|
num_classes=num_classes,
|
|
is_multiclass=is_multiclass,
|
|
top_k=top_k,
|
|
)
|
|
|
|
assert mode == exp_mode
|
|
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
|
|
assert torch.equal(target_out, post_target(inputs.target[0]).int())
|
|
|
|
# Test that things work when batch_size = 1
|
|
preds_out, target_out, mode = _input_format_classification(
|
|
preds=inputs.preds[0][[0], ...],
|
|
target=inputs.target[0][[0], ...],
|
|
threshold=THRESHOLD,
|
|
num_classes=num_classes,
|
|
is_multiclass=is_multiclass,
|
|
top_k=top_k,
|
|
)
|
|
|
|
assert mode == exp_mode
|
|
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
|
|
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
|
|
|
|
|
|
# Test that threshold is correctly applied
|
|
def test_threshold():
|
|
target = T([1, 1, 1]).int()
|
|
preds_probs = T([0.5 - 1e-5, 0.5, 0.5 + 1e-5])
|
|
|
|
preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5)
|
|
|
|
assert torch.equal(torch.tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int())
|
|
|
|
|
|
########################################################################
|
|
# Test incorrect inputs
|
|
########################################################################
|
|
|
|
|
|
@pytest.mark.parametrize("threshold", [-0.5, 0.0, 1.0, 1.5])
|
|
def test_incorrect_threshold(threshold):
|
|
preds, target = rand(size=(7, )), randint(high=2, size=(7, ))
|
|
with pytest.raises(ValueError):
|
|
_input_format_classification(preds, target, threshold=threshold)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"preds, target, num_classes, is_multiclass",
|
|
[
|
|
# Target not integer
|
|
(randint(high=2, size=(7, )), randint(high=2, size=(7, )).float(), None, None),
|
|
# Target negative
|
|
(randint(high=2, size=(7, )), -randint(high=2, size=(7, )), None, None),
|
|
# Preds negative integers
|
|
(-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None),
|
|
# Negative probabilities
|
|
(-rand(size=(7, )), randint(high=2, size=(7, )), None, None),
|
|
# is_multiclass=False and target > 1
|
|
(rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False),
|
|
# is_multiclass=False and preds integers with > 1
|
|
(randint(low=2, high=4, size=(7, )), randint(high=2, size=(7, )), None, False),
|
|
# Wrong batch size
|
|
(randint(high=2, size=(8, )), randint(high=2, size=(7, )), None, None),
|
|
# Completely wrong shape
|
|
(randint(high=2, size=(7, )), randint(high=2, size=(7, 4)), None, None),
|
|
# Same #dims, different shape
|
|
(randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), None, None),
|
|
# Same shape and preds floats, target not binary
|
|
(rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), None, None),
|
|
# #dims in preds = 1 + #dims in target, C shape not second or last
|
|
(rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), None, None),
|
|
# #dims in preds = 1 + #dims in target, preds not float
|
|
(randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None),
|
|
# is_multiclass=False, with C dimension > 2
|
|
(_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False),
|
|
# Probs of multiclass preds do not sum up to 1
|
|
(rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None),
|
|
# Max target larger or equal to C dimension
|
|
(_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, )), None, None),
|
|
# C dimension not equal to num_classes
|
|
(_mc_prob.preds[0], _mc_prob.target[0], NUM_CLASSES + 1, None),
|
|
# Max target larger than num_classes (with #dim preds = 1 + #dims target)
|
|
(_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), 4, None),
|
|
# Max target larger than num_classes (with #dim preds = #dims target)
|
|
(randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None),
|
|
# Max preds larger than num_classes (with #dim preds = #dims target)
|
|
(randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None),
|
|
# Num_classes=1, but is_multiclass not false
|
|
(randint(high=2, size=(7, )), randint(high=2, size=(7, )), 1, None),
|
|
# is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes
|
|
(randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False),
|
|
# Multilabel input with implied class dimension != num_classes
|
|
(rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False),
|
|
# Multilabel input with is_multiclass=True, but num_classes != 2 (or None)
|
|
(rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True),
|
|
# Binary input, num_classes > 2
|
|
(rand(size=(7, )), randint(high=2, size=(7, )), 4, None),
|
|
# Binary input, num_classes == 2 and is_multiclass not True
|
|
(rand(size=(7, )), randint(high=2, size=(7, )), 2, None),
|
|
(rand(size=(7, )), randint(high=2, size=(7, )), 2, False),
|
|
# Binary input, num_classes == 1 and is_multiclass=True
|
|
(rand(size=(7, )), randint(high=2, size=(7, )), 1, True),
|
|
],
|
|
)
|
|
def test_incorrect_inputs(preds, target, num_classes, is_multiclass):
|
|
with pytest.raises(ValueError):
|
|
_input_format_classification(
|
|
preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"preds, target, num_classes, is_multiclass, top_k",
|
|
[
|
|
# Topk set with non (md)mc or ml prob data
|
|
(_bin.preds[0], _bin.target[0], None, None, 2),
|
|
(_bin_prob.preds[0], _bin_prob.target[0], None, None, 2),
|
|
(_mc.preds[0], _mc.target[0], None, None, 2),
|
|
(_ml.preds[0], _ml.target[0], None, None, 2),
|
|
(_mlmd.preds[0], _mlmd.target[0], None, None, 2),
|
|
(_mdmc.preds[0], _mdmc.target[0], None, None, 2),
|
|
# top_k = 0
|
|
(_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0),
|
|
# top_k = float
|
|
(_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0.123),
|
|
# top_k =2 with 2 classes, is_multiclass=False
|
|
(_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 2),
|
|
# top_k = number of classes (C dimension)
|
|
(_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES),
|
|
# is_multiclass = True for ml prob inputs, top_k set
|
|
(_ml_prob.preds[0], _ml_prob.target[0], None, True, 2),
|
|
# top_k = num_classes for ml prob inputs
|
|
(_ml_prob.preds[0], _ml_prob.target[0], None, True, NUM_CLASSES),
|
|
],
|
|
)
|
|
def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k):
|
|
with pytest.raises(ValueError):
|
|
_input_format_classification(
|
|
preds=preds,
|
|
target=target,
|
|
threshold=THRESHOLD,
|
|
num_classes=num_classes,
|
|
is_multiclass=is_multiclass,
|
|
top_k=top_k,
|
|
)
|