67 lines
2.5 KiB
Python
67 lines
2.5 KiB
Python
from collections import namedtuple
|
|
|
|
import torch
|
|
|
|
from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES
|
|
|
|
# Lightweight (preds, target) pair used to bundle each test fixture below.
Input = namedtuple("Input", ("preds", "target"))
|
|
|
|
# Float predictions against 0/1 targets.
# `preds` holds uniform random values in [0, 1); `target` holds random integers
# in {0, 1}. Both have shape (NUM_BATCHES, BATCH_SIZE).
_input_binary_prob = Input(
    preds=torch.rand((NUM_BATCHES, BATCH_SIZE)),
    target=torch.randint(2, (NUM_BATCHES, BATCH_SIZE)),
)
|
|
|
|
# Integer 0/1 predictions against 0/1 targets.
# Both tensors are drawn independently from {0, 1} with shape
# (NUM_BATCHES, BATCH_SIZE).
_input_binary = Input(
    preds=torch.randint(2, (NUM_BATCHES, BATCH_SIZE)),
    target=torch.randint(2, (NUM_BATCHES, BATCH_SIZE)),
)
|
|
|
|
# Float predictions with a trailing NUM_CLASSES axis against 0/1 targets.
# `preds` is uniform in [0, 1); `target` is drawn from {0, 1}. Both have shape
# (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).
_input_multilabel_prob = Input(
    preds=torch.rand((NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)),
    target=torch.randint(2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)),
)
|
|
|
|
# Same as the multilabel float fixture, with an additional EXTRA_DIM axis.
# `preds` is uniform in [0, 1); `target` is drawn from {0, 1}. Both have shape
# (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).
_input_multilabel_multidim_prob = Input(
    preds=torch.rand((NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)),
    target=torch.randint(2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)),
)
|
|
|
|
# Integer 0/1 predictions and targets with a trailing NUM_CLASSES axis.
# Both tensors are drawn independently from {0, 1} with shape
# (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).
_input_multilabel = Input(
    preds=torch.randint(2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)),
    target=torch.randint(2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)),
)
|
|
|
|
# Integer 0/1 predictions and targets with NUM_CLASSES and EXTRA_DIM axes.
# Both tensors are drawn independently from {0, 1} with shape
# (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).
_input_multilabel_multidim = Input(
    preds=torch.randint(2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)),
    target=torch.randint(2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)),
)
|
|
|
|
# Multilabel edge case where nothing matches (scores are undefined):
# the target is the element-wise complement of the 0/1 predictions, so no
# position of `preds` ever equals the same position of `target`.
__temp_preds = torch.randint(2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
__temp_target = 1 - __temp_preds  # flips every 0 to 1 and every 1 to 0

_input_multilabel_no_match = Input(preds=__temp_preds, target=__temp_target)
|
|
|
|
# Float predictions normalised over the NUM_CLASSES axis against integer labels.
# Raw scores are uniform in [0, 1) with shape (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).
__mc_prob_preds = torch.rand((NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
# Divide by the sum along dim 2 so the values for each sample add up to 1.
__mc_prob_preds = __mc_prob_preds.div(__mc_prob_preds.sum(dim=2, keepdim=True))

# `target` holds integer labels in [0, NUM_CLASSES) with shape (NUM_BATCHES, BATCH_SIZE).
_input_multiclass_prob = Input(
    preds=__mc_prob_preds,
    target=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE)),
)
|
|
|
|
# Integer label predictions against integer label targets.
# Both tensors are drawn independently from [0, NUM_CLASSES) with shape
# (NUM_BATCHES, BATCH_SIZE).
_input_multiclass = Input(
    preds=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE)),
    target=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE)),
)
|
|
|
|
# Float predictions with an EXTRA_DIM axis, normalised over the NUM_CLASSES axis.
# Raw scores are uniform in [0, 1) with shape
# (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).
__mdmc_prob_preds = torch.rand((NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM))
# Divide by the sum along dim 2 (the NUM_CLASSES axis) so those values add up to 1.
__mdmc_prob_preds = __mdmc_prob_preds.div(__mdmc_prob_preds.sum(dim=2, keepdim=True))

# `target` holds integer labels in [0, NUM_CLASSES) with shape
# (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM).
_input_multidim_multiclass_prob = Input(
    preds=__mdmc_prob_preds,
    target=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
)
|
|
|
|
# Integer label predictions and targets with an EXTRA_DIM axis.
# Both tensors are drawn independently from [0, NUM_CLASSES) with shape
# (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM).
_input_multidim_multiclass = Input(
    preds=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
    target=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
)
|