Update metrics to use Enum (#5689)
- Add DataType, AverageMethod and MDMCAverageMethod
This commit is contained in:
parent
8943d8bca0
commit
50fd4879a9
|
@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689)
|
||||||
|
|
||||||
|
|
||||||
- Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))
|
- Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,39 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_lightning.metrics.utils import select_topk, to_onehot
|
from pytorch_lightning.metrics.utils import select_topk, to_onehot
|
||||||
|
from pytorch_lightning.utilities import LightningEnum
|
||||||
|
|
||||||
|
|
||||||
|
class DataType(LightningEnum):
|
||||||
|
"""
|
||||||
|
Enum to represent data type
|
||||||
|
"""
|
||||||
|
|
||||||
|
BINARY = "binary"
|
||||||
|
MULTILABEL = "multi-label"
|
||||||
|
MULTICLASS = "multi-class"
|
||||||
|
MULTIDIM_MULTICLASS = "multi-dim multi-class"
|
||||||
|
|
||||||
|
|
||||||
|
class AverageMethod(LightningEnum):
|
||||||
|
"""
|
||||||
|
Enum to represent average method
|
||||||
|
"""
|
||||||
|
|
||||||
|
MICRO = "micro"
|
||||||
|
MACRO = "macro"
|
||||||
|
WEIGHTED = "weighted"
|
||||||
|
NONE = "none"
|
||||||
|
SAMPLES = "samples"
|
||||||
|
|
||||||
|
|
||||||
|
class MDMCAverageMethod(LightningEnum):
|
||||||
|
"""
|
||||||
|
Enum to represent multi-dim multi-class average method
|
||||||
|
"""
|
||||||
|
|
||||||
|
GLOBAL = "global"
|
||||||
|
SAMPLEWISE = "samplewise"
|
||||||
|
|
||||||
|
|
||||||
def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
|
def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
|
||||||
|
@ -78,13 +111,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
|
||||||
|
|
||||||
# Get the case
|
# Get the case
|
||||||
if preds.ndim == 1 and preds_float:
|
if preds.ndim == 1 and preds_float:
|
||||||
case = "binary"
|
case = DataType.BINARY
|
||||||
elif preds.ndim == 1 and not preds_float:
|
elif preds.ndim == 1 and not preds_float:
|
||||||
case = "multi-class"
|
case = DataType.MULTICLASS
|
||||||
elif preds.ndim > 1 and preds_float:
|
elif preds.ndim > 1 and preds_float:
|
||||||
case = "multi-label"
|
case = DataType.MULTILABEL
|
||||||
else:
|
else:
|
||||||
case = "multi-dim multi-class"
|
case = DataType.MULTIDIM_MULTICLASS
|
||||||
|
|
||||||
implied_classes = preds[0].numel()
|
implied_classes = preds[0].numel()
|
||||||
|
|
||||||
|
@ -100,9 +133,9 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
|
||||||
implied_classes = preds.shape[1]
|
implied_classes = preds.shape[1]
|
||||||
|
|
||||||
if preds.ndim == 2:
|
if preds.ndim == 2:
|
||||||
case = "multi-class"
|
case = DataType.MULTICLASS
|
||||||
else:
|
else:
|
||||||
case = "multi-dim multi-class"
|
case = DataType.MULTIDIM_MULTICLASS
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
|
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
|
||||||
|
@ -182,7 +215,7 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes
|
||||||
|
|
||||||
|
|
||||||
def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
|
def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
|
||||||
if case == "binary":
|
if case == DataType.BINARY:
|
||||||
raise ValueError("You can not use `top_k` parameter with binary data.")
|
raise ValueError("You can not use `top_k` parameter with binary data.")
|
||||||
if not isinstance(top_k, int) or top_k <= 0:
|
if not isinstance(top_k, int) or top_k <= 0:
|
||||||
raise ValueError("The `top_k` has to be an integer larger than 0.")
|
raise ValueError("The `top_k` has to be an integer larger than 0.")
|
||||||
|
@ -190,7 +223,7 @@ def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Opt
|
||||||
raise ValueError("You have set `top_k`, but you do not have probability predictions.")
|
raise ValueError("You have set `top_k`, but you do not have probability predictions.")
|
||||||
if is_multiclass is False:
|
if is_multiclass is False:
|
||||||
raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
|
raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
|
||||||
if case == "multi-label" and is_multiclass:
|
if case == DataType.MULTILABEL and is_multiclass:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If you want to transform multi-label data to 2 class multi-dimensional"
|
"If you want to transform multi-label data to 2 class multi-dimensional"
|
||||||
"multi-class data using `is_multiclass=True`, you can not use `top_k`."
|
"multi-class data using `is_multiclass=True`, you can not use `top_k`."
|
||||||
|
@ -266,7 +299,7 @@ def _check_classification_inputs(
|
||||||
case, implied_classes = _check_shape_and_type_consistency(preds, target)
|
case, implied_classes = _check_shape_and_type_consistency(preds, target)
|
||||||
|
|
||||||
# For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
|
# For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
|
||||||
if "multi-class" in case and preds.is_floating_point():
|
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
|
||||||
if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
|
if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
|
||||||
raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")
|
raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")
|
||||||
|
|
||||||
|
@ -284,11 +317,11 @@ def _check_classification_inputs(
|
||||||
|
|
||||||
# Check that num_classes is consistent
|
# Check that num_classes is consistent
|
||||||
if num_classes:
|
if num_classes:
|
||||||
if case == "binary":
|
if case == DataType.BINARY:
|
||||||
_check_num_classes_binary(num_classes, is_multiclass)
|
_check_num_classes_binary(num_classes, is_multiclass)
|
||||||
elif "multi-class" in case:
|
elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS):
|
||||||
_check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
|
_check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
|
||||||
elif case == "multi-label":
|
elif case.MULTILABEL:
|
||||||
_check_num_classes_ml(num_classes, is_multiclass, implied_classes)
|
_check_num_classes_ml(num_classes, is_multiclass, implied_classes)
|
||||||
|
|
||||||
# Check that top_k is consistent
|
# Check that top_k is consistent
|
||||||
|
@ -406,14 +439,14 @@ def _input_format_classification(
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
if case in ["binary", "multi-label"] and not top_k:
|
if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
|
||||||
preds = (preds >= threshold).int()
|
preds = (preds >= threshold).int()
|
||||||
num_classes = num_classes if not is_multiclass else 2
|
num_classes = num_classes if not is_multiclass else 2
|
||||||
|
|
||||||
if case == "multi-label" and top_k:
|
if case == DataType.MULTILABEL and top_k:
|
||||||
preds = select_topk(preds, top_k)
|
preds = select_topk(preds, top_k)
|
||||||
|
|
||||||
if "multi-class" in case or is_multiclass:
|
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass:
|
||||||
if preds.is_floating_point():
|
if preds.is_floating_point():
|
||||||
num_classes = preds.shape[1]
|
num_classes = preds.shape[1]
|
||||||
preds = select_topk(preds, top_k or 1)
|
preds = select_topk(preds, top_k or 1)
|
||||||
|
@ -426,7 +459,7 @@ def _input_format_classification(
|
||||||
if is_multiclass is False:
|
if is_multiclass is False:
|
||||||
preds, target = preds[:, 1, ...], target[:, 1, ...]
|
preds, target = preds[:, 1, ...], target[:, 1, ...]
|
||||||
|
|
||||||
if ("multi-class" in case and is_multiclass is not False) or is_multiclass:
|
if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass:
|
||||||
target = target.reshape(target.shape[0], target.shape[1], -1)
|
target = target.reshape(target.shape[0], target.shape[1], -1)
|
||||||
preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
|
preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
|
||||||
else:
|
else:
|
||||||
|
@ -486,7 +519,7 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
|
||||||
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
|
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
|
||||||
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
|
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
|
||||||
|
|
||||||
if average not in ["micro", "none", None]:
|
if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
|
||||||
weights = weights / weights.sum(dim=-1, keepdim=True)
|
weights = weights / weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
scores = weights * (numerator / denominator)
|
scores = weights * (numerator / denominator)
|
||||||
|
@ -494,11 +527,11 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
|
||||||
# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
|
# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
|
||||||
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
|
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
|
||||||
|
|
||||||
if mdmc_average == "samplewise":
|
if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
|
||||||
scores = scores.mean(dim=0)
|
scores = scores.mean(dim=0)
|
||||||
ignore_mask = ignore_mask.sum(dim=0).bool()
|
ignore_mask = ignore_mask.sum(dim=0).bool()
|
||||||
|
|
||||||
if average in ["none", None]:
|
if average in (AverageMethod.NONE, None):
|
||||||
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
|
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
|
||||||
else:
|
else:
|
||||||
scores = scores.sum()
|
scores = scores.sum()
|
||||||
|
|
|
@ -15,7 +15,7 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
|
||||||
|
|
||||||
|
|
||||||
def _accuracy_update(
|
def _accuracy_update(
|
||||||
|
@ -24,19 +24,19 @@ def _accuracy_update(
|
||||||
|
|
||||||
preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
|
preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
|
||||||
|
|
||||||
if mode == "multi-label" and top_k:
|
if mode == DataType.MULTILABEL and top_k:
|
||||||
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
|
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
|
||||||
|
|
||||||
if mode == "binary" or (mode == "multi-label" and subset_accuracy):
|
if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
|
||||||
correct = (preds == target).all(dim=1).sum()
|
correct = (preds == target).all(dim=1).sum()
|
||||||
total = torch.tensor(target.shape[0], device=target.device)
|
total = torch.tensor(target.shape[0], device=target.device)
|
||||||
elif mode == "multi-label" and not subset_accuracy:
|
elif mode == DataType.MULTILABEL and not subset_accuracy:
|
||||||
correct = (preds == target).sum()
|
correct = (preds == target).sum()
|
||||||
total = torch.tensor(target.numel(), device=target.device)
|
total = torch.tensor(target.numel(), device=target.device)
|
||||||
elif mode == "multi-class" or (mode == "multi-dim multi-class" and not subset_accuracy):
|
elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
|
||||||
correct = (preds * target).sum()
|
correct = (preds * target).sum()
|
||||||
total = target.sum()
|
total = target.sum()
|
||||||
elif mode == "multi-dim multi-class" and subset_accuracy:
|
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
|
||||||
sample_correct = (preds * target).sum(dim=(1, 2))
|
sample_correct = (preds * target).sum(dim=(1, 2))
|
||||||
correct = (sample_correct == target.shape[2]).sum()
|
correct = (sample_correct == target.shape[2]).sum()
|
||||||
total = torch.tensor(target.shape[0], device=target.device)
|
total = torch.tensor(target.shape[0], device=target.device)
|
||||||
|
|
|
@ -16,7 +16,7 @@ from typing import Optional, Sequence, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
|
||||||
from pytorch_lightning.metrics.functional.auc import auc
|
from pytorch_lightning.metrics.functional.auc import auc
|
||||||
from pytorch_lightning.metrics.functional.roc import roc
|
from pytorch_lightning.metrics.functional.roc import roc
|
||||||
from pytorch_lightning.utilities import LightningEnum
|
from pytorch_lightning.utilities import LightningEnum
|
||||||
|
@ -102,7 +102,7 @@ def _auroc_compute(
|
||||||
elif average == AverageMethods.MACRO:
|
elif average == AverageMethods.MACRO:
|
||||||
return torch.mean(torch.stack(auc_scores))
|
return torch.mean(torch.stack(auc_scores))
|
||||||
elif average == AverageMethods.WEIGHTED:
|
elif average == AverageMethods.WEIGHTED:
|
||||||
if mode == 'multi-label':
|
if mode == DataType.MULTILABEL:
|
||||||
support = torch.sum(target, dim=0)
|
support = torch.sum(target, dim=0)
|
||||||
else:
|
else:
|
||||||
support = torch.bincount(target.flatten(), minlength=num_classes)
|
support = torch.bincount(target.flatten(), minlength=num_classes)
|
||||||
|
|
|
@ -15,7 +15,7 @@ from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
|
||||||
from pytorch_lightning.utilities import rank_zero_warn
|
from pytorch_lightning.utilities import rank_zero_warn
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ def _confusion_matrix_update(
|
||||||
preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5
|
preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
preds, target, mode = _input_format_classification(preds, target, threshold)
|
preds, target, mode = _input_format_classification(preds, target, threshold)
|
||||||
if mode not in ('binary', 'multi-label'):
|
if mode not in (DataType.BINARY, DataType.MULTILABEL):
|
||||||
preds = preds.argmax(dim=1)
|
preds = preds.argmax(dim=1)
|
||||||
target = target.argmax(dim=1)
|
target = target.argmax(dim=1)
|
||||||
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
|
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
from sklearn.metrics import accuracy_score as sk_accuracy
|
from sklearn.metrics import accuracy_score as sk_accuracy
|
||||||
|
|
||||||
from pytorch_lightning.metrics import Accuracy
|
from pytorch_lightning.metrics import Accuracy
|
||||||
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
|
||||||
from pytorch_lightning.metrics.functional import accuracy
|
from pytorch_lightning.metrics.functional import accuracy
|
||||||
from tests.metrics.classification.inputs import (
|
from tests.metrics.classification.inputs import (
|
||||||
_binary_inputs,
|
_binary_inputs,
|
||||||
|
@ -29,12 +29,12 @@ def _sk_accuracy(preds, target, subset_accuracy):
|
||||||
sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
|
sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
|
||||||
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()
|
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()
|
||||||
|
|
||||||
if mode == "multi-dim multi-class" and not subset_accuracy:
|
if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy:
|
||||||
sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1))
|
sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1))
|
||||||
sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2])
|
sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2])
|
||||||
elif mode == mode == "multi-dim multi-class" and subset_accuracy:
|
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
|
||||||
return np.all(sk_preds == sk_target, axis=(1, 2)).mean()
|
return np.all(sk_preds == sk_target, axis=(1, 2)).mean()
|
||||||
elif mode == "multi-label" and not subset_accuracy:
|
elif mode == DataType.MULTILABEL and not subset_accuracy:
|
||||||
sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1)
|
sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1)
|
||||||
|
|
||||||
return sk_accuracy(y_true=sk_target, y_pred=sk_preds)
|
return sk_accuracy(y_true=sk_target, y_pred=sk_preds)
|
||||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import rand, randint
|
from torch import rand, randint
|
||||||
|
|
||||||
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
|
||||||
from pytorch_lightning.metrics.utils import select_topk, to_onehot
|
from pytorch_lightning.metrics.utils import select_topk, to_onehot
|
||||||
from tests.metrics.classification.inputs import _binary_inputs as _bin
|
from tests.metrics.classification.inputs import _binary_inputs as _bin
|
||||||
from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
|
from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
|
||||||
|
@ -155,32 +155,36 @@ def _mlmd_prob_to_mc_preds_tr(x):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
|
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
|
||||||
preds_out, target_out, mode = _input_format_classification(
|
def __get_data_type_enum(str_exp_mode):
|
||||||
preds=inputs.preds[0],
|
return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)
|
||||||
target=inputs.target[0],
|
|
||||||
threshold=THRESHOLD,
|
|
||||||
num_classes=num_classes,
|
|
||||||
is_multiclass=is_multiclass,
|
|
||||||
top_k=top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mode == exp_mode
|
for exp_mode in (exp_mode, __get_data_type_enum(exp_mode)):
|
||||||
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
|
preds_out, target_out, mode = _input_format_classification(
|
||||||
assert torch.equal(target_out, post_target(inputs.target[0]).int())
|
preds=inputs.preds[0],
|
||||||
|
target=inputs.target[0],
|
||||||
|
threshold=THRESHOLD,
|
||||||
|
num_classes=num_classes,
|
||||||
|
is_multiclass=is_multiclass,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
|
||||||
# Test that things work when batch_size = 1
|
assert mode == exp_mode
|
||||||
preds_out, target_out, mode = _input_format_classification(
|
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
|
||||||
preds=inputs.preds[0][[0], ...],
|
assert torch.equal(target_out, post_target(inputs.target[0]).int())
|
||||||
target=inputs.target[0][[0], ...],
|
|
||||||
threshold=THRESHOLD,
|
|
||||||
num_classes=num_classes,
|
|
||||||
is_multiclass=is_multiclass,
|
|
||||||
top_k=top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mode == exp_mode
|
# Test that things work when batch_size = 1
|
||||||
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
|
preds_out, target_out, mode = _input_format_classification(
|
||||||
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
|
preds=inputs.preds[0][[0], ...],
|
||||||
|
target=inputs.target[0][[0], ...],
|
||||||
|
threshold=THRESHOLD,
|
||||||
|
num_classes=num_classes,
|
||||||
|
is_multiclass=is_multiclass,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mode == exp_mode
|
||||||
|
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
|
||||||
|
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
|
||||||
|
|
||||||
|
|
||||||
# Test that threshold is correctly applied
|
# Test that threshold is correctly applied
|
||||||
|
|
Loading…
Reference in New Issue