# lightning/tests/metrics/test_remove_1-5_metrics.py
#
# 349 lines
# 12 KiB
# Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.5.0"""
import pytest
import torch
from pytorch_lightning.metrics import (
Accuracy,
AUC,
AUROC,
AveragePrecision,
ConfusionMatrix,
ExplainedVariance,
F1,
FBeta,
HammingDistance,
IoU,
MeanAbsoluteError,
MeanSquaredError,
MeanSquaredLogError,
MetricCollection,
Precision,
PrecisionRecallCurve,
PSNR,
R2Score,
Recall,
ROC,
SSIM,
StatScores,
)
from pytorch_lightning.metrics.functional import (
auc,
auroc,
average_precision,
bleu_score,
confusion_matrix,
embedding_similarity,
explained_variance,
f1,
fbeta,
hamming_distance,
iou,
mean_absolute_error,
mean_squared_error,
mean_squared_log_error,
precision,
precision_recall,
precision_recall_curve,
psnr,
r2score,
recall,
roc,
ssim,
stat_scores,
)
from pytorch_lightning.metrics.functional.accuracy import accuracy
from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error
from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot
def test_v1_5_metrics_utils():
    """Exercise the deprecated helpers imported from ``pytorch_lightning.metrics.utils``.

    ``to_onehot``, ``get_num_classes``, ``select_topk`` and ``to_categorical`` are each called inside
    their own ``pytest.deprecated_call`` context (so a matching deprecation warning must be emitted),
    and every returned value is compared against a hard-coded expectation.
    """
    removal_note = "It will be removed in v1.5.0"

    # to_onehot: integer labels -> one-hot rows
    labels = torch.tensor([1, 2, 3])
    with pytest.deprecated_call(match=removal_note):
        onehot = to_onehot(labels)
    assert torch.equal(onehot, torch.Tensor([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).to(int))

    # get_num_classes: called with a prediction tensor and a target tensor
    with pytest.deprecated_call(match=removal_note):
        class_count = get_num_classes(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 0]))
    assert class_count == 4

    # select_topk: binary mask of the two largest entries per row
    scores = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
    with pytest.deprecated_call(match=removal_note):
        topk_mask = select_topk(scores, topk=2)
    assert torch.equal(topk_mask, torch.Tensor([[0, 1, 1], [1, 1, 0]]).to(torch.int32))

    # to_categorical: per-row index of the largest probability
    probabilities = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
    with pytest.deprecated_call(match=removal_note):
        categories = to_categorical(probabilities)
    assert torch.equal(categories, torch.Tensor([1, 0]).to(int))
def test_v1_5_metrics_collection():
    """Build a ``MetricCollection`` under ``pytest.deprecated_call`` and check the value it computes.

    The collection wraps a single ``Accuracy`` metric; calling it on the fixed inputs below must
    return ``{'Accuracy': tensor(0.1250)}``.
    """
    labels = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
    predictions = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])

    # Presumably the deprecation wrapper warns only once and tracks that in ``_warned``;
    # resetting it lets the warning be observed here - confirm against the wrapper.
    MetricCollection.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0."):
        collection = MetricCollection([Accuracy()])

    expected = {'Accuracy': torch.tensor(0.1250)}
    assert collection(predictions, labels) == expected
def test_v1_5_metric_accuracy():
    """Check the deprecated functional ``accuracy`` and the ``Accuracy`` class.

    Both must emit a deprecation warning matching the v1.5.0 removal note; the functional call must
    additionally return 0.8 on the fixed inputs below.
    """
    removal_note = 'It will be removed in v1.5.0'

    predictions = torch.tensor([0, 0, 1, 0, 1])
    labels = torch.tensor([0, 0, 1, 1, 1])

    # Functional interface. ``_warned`` is reset before the call; presumably the wrapper
    # warns only once per callable - confirm against the deprecation helper.
    accuracy._warned = False
    with pytest.deprecated_call(match=removal_note):
        score = accuracy(predictions, labels)
    assert score == torch.tensor(0.8)

    # Class interface: only construction is checked.
    Accuracy.__init__._warned = False
    with pytest.deprecated_call(match=removal_note):
        Accuracy()
def test_v1_5_metric_auc_auroc():
    """Check the deprecated AUC / ROC / AUROC metrics, both class and functional forms.

    The classes are only constructed; the functional ``auc``, ``roc`` and ``auroc`` are called on
    fixed inputs and their outputs compared with hard-coded expectations. Every call happens inside
    ``pytest.deprecated_call``.
    """
    removal_note = 'It will be removed in v1.5.0'

    # Class interfaces: reset the ``_warned`` flag, then construct with default arguments.
    for metric_cls in (AUC, ROC, AUROC):
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_note):
            metric_cls()

    # Functional auc
    xs = torch.tensor([0, 1, 2, 3])
    ys = torch.tensor([0, 1, 2, 2])
    auc._warned = False
    with pytest.deprecated_call(match=removal_note):
        area = auc(xs, ys)
    assert area == torch.tensor(4.)

    # Functional roc
    scores = torch.tensor([0, 1, 2, 3])
    labels = torch.tensor([0, 1, 1, 1])
    roc._warned = False
    with pytest.deprecated_call(match=removal_note):
        false_pos_rate, true_pos_rate, thresholds = roc(scores, labels, pos_label=1)
    assert torch.equal(false_pos_rate, torch.tensor([0., 0., 0., 0., 1.]))
    assert torch.allclose(true_pos_rate, torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]), atol=1e-4)
    assert torch.equal(thresholds, torch.tensor([4, 3, 2, 1, 0]))

    # Functional auroc
    probabilities = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
    labels = torch.tensor([0, 0, 1, 1, 1])
    auroc._warned = False
    with pytest.deprecated_call(match=removal_note):
        auroc_value = auroc(probabilities, labels)
    assert auroc_value == torch.tensor(0.5)
def test_v1_5_metric_precision_recall():
    """Check the deprecated precision/recall family, both class and functional forms.

    The classes are only constructed; the functional metrics are called on one shared pair of fixed
    inputs and their outputs compared with hard-coded expectations. Every call happens inside
    ``pytest.deprecated_call``.
    """
    removal_note = 'It will be removed in v1.5.0'

    # Class interfaces: reset the ``_warned`` flag, then construct with default arguments.
    for metric_cls in (AveragePrecision, Precision, Recall, PrecisionRecallCurve):
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_note):
            metric_cls()

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 1, 1])

    # Functional metrics returning a single scalar.
    scalar_cases = (
        (average_precision, 1.),
        (precision, 0.5),
        (recall, 0.5),
    )
    for metric_fn, expected in scalar_cases:
        metric_fn._warned = False
        with pytest.deprecated_call(match=removal_note):
            value = metric_fn(pred, target)
        assert value == torch.tensor(expected)

    # precision_recall returns both values at once.
    precision_recall._warned = False
    with pytest.deprecated_call(match=removal_note):
        prec_value, recall_value = precision_recall(pred, target)
    assert prec_value == torch.tensor(0.5)
    assert recall_value == torch.tensor(0.5)

    # precision_recall_curve returns per-threshold precision, recall and the thresholds.
    precision_recall_curve._warned = False
    with pytest.deprecated_call(match=removal_note):
        prec_curve, recall_curve, thresholds = precision_recall_curve(pred, target)
    assert torch.equal(prec_curve, torch.tensor([1., 1., 1., 1.]))
    assert torch.allclose(recall_curve, torch.tensor([1., 0.6667, 0.3333, 0.]), atol=1e-4)
    assert torch.equal(thresholds, torch.tensor([1, 2, 3]))
def test_v1_5_metric_classif_mix():
    """Check a mix of deprecated classification metrics, both class and functional forms.

    Covers ConfusionMatrix, FBeta, F1, HammingDistance and StatScores (construction only) and their
    functional counterparts (called on fixed inputs, results compared with hard-coded expectations).
    Every call happens inside ``pytest.deprecated_call``.
    """
    removal_note = 'It will be removed in v1.5.0'

    # Class interfaces, each with the constructor arguments used for the check.
    class_cases = (
        (ConfusionMatrix, {'num_classes': 1}),
        (FBeta, {'num_classes': 1}),
        (F1, {'num_classes': 1}),
        (HammingDistance, {}),
        (StatScores, {}),
    )
    for metric_cls, init_kwargs in class_cases:
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_note):
            metric_cls(**init_kwargs)

    # confusion_matrix
    target = torch.tensor([1, 1, 0, 0])
    preds = torch.tensor([0, 1, 0, 0])
    confusion_matrix._warned = False
    with pytest.deprecated_call(match=removal_note):
        matrix = confusion_matrix(preds, target, num_classes=2)
    assert torch.equal(matrix, torch.tensor([[2., 0.], [1., 1.]]))

    # fbeta and f1 share the same three-class inputs
    target = torch.tensor([0, 1, 2, 0, 1, 2])
    preds = torch.tensor([0, 2, 1, 0, 0, 1])
    fbeta._warned = False
    with pytest.deprecated_call(match=removal_note):
        fbeta_value = fbeta(preds, target, num_classes=3, beta=0.5)
    assert torch.allclose(fbeta_value, torch.tensor(0.3333), atol=1e-4)

    f1._warned = False
    with pytest.deprecated_call(match=removal_note):
        f1_value = f1(preds, target, num_classes=3)
    assert torch.allclose(f1_value, torch.tensor(0.3333), atol=1e-4)

    # hamming_distance
    target = torch.tensor([[0, 1], [1, 1]])
    preds = torch.tensor([[0, 1], [0, 1]])
    hamming_distance._warned = False
    with pytest.deprecated_call(match=removal_note):
        distance = hamming_distance(preds, target)
    assert distance == torch.tensor(0.25)

    # stat_scores
    preds = torch.tensor([1, 0, 2, 1])
    target = torch.tensor([1, 1, 2, 0])
    stat_scores._warned = False
    with pytest.deprecated_call(match=removal_note):
        scores = stat_scores(preds, target, reduce='micro')
    assert torch.equal(scores, torch.tensor([2, 2, 6, 2, 4]))
def test_v1_5_metric_detect():
    """Check the deprecated ``IoU`` class and functional ``iou``.

    The class is only constructed. The functional form is evaluated on a random binary target and a
    copy of it in which one sub-block has been inverted; the result must be close to 0.9660. Both
    calls happen inside ``pytest.deprecated_call``.
    """
    removal_note = 'It will be removed in v1.5.0'

    # Presumably the deprecation wrapper warns only once and tracks that in ``_warned``;
    # resetting it lets the warning be observed here - confirm against the wrapper.
    IoU.__init__._warned = False
    with pytest.deprecated_call(match=removal_note):
        IoU(num_classes=1)

    # NOTE(review): the target is random and unseeded, so the 1e-4 tolerance below relies on the
    # score being nearly independent of the draw - confirm this is stable enough for CI.
    target = torch.randint(0, 2, (10, 25, 25))
    # Copy with clone().detach() rather than torch.tensor(target): copy-constructing from an
    # existing tensor is discouraged by PyTorch and emits a UserWarning.
    preds = target.clone().detach()
    # Invert a 3x6x6 sub-block so predictions differ from the target there.
    preds[2:5, 7:13, 9:15] = 1 - preds[2:5, 7:13, 9:15]

    iou._warned = False
    with pytest.deprecated_call(match=removal_note):
        res = iou(preds, target)
    assert torch.allclose(res, torch.tensor(0.9660), atol=1e-4)
def test_v1_5_metric_regress():
    """Check the deprecated regression metrics, both class and functional forms.

    Covers ExplainedVariance, MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError, PSNR,
    R2Score and SSIM (construction only) plus their functional counterparts and
    ``mean_relative_error`` (called on the inputs below, results compared with hard-coded
    expectations). Every call happens inside ``pytest.deprecated_call``.
    """
    removal_note = 'It will be removed in v1.5.0'

    # First group of class interfaces: reset ``_warned``, then construct with defaults.
    for metric_cls in (ExplainedVariance, MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError):
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_note):
            metric_cls()

    # explained_variance
    target = torch.tensor([3, -0.5, 2, 7])
    preds = torch.tensor([2.5, 0.0, 2, 8])
    explained_variance._warned = False
    with pytest.deprecated_call(match=removal_note):
        variance = explained_variance(preds, target)
    assert torch.allclose(variance, torch.tensor(0.9572), atol=1e-4)

    # Error metrics whose result is compared directly against a Python float.
    x = torch.tensor([0., 1, 2, 3])
    y = torch.tensor([0., 1, 2, 2])
    exact_cases = (
        (mean_absolute_error, 0.25),
        (mean_relative_error, 0.125),
        (mean_squared_error, 0.25),
    )
    for metric_fn, expected in exact_cases:
        metric_fn._warned = False
        with pytest.deprecated_call(match=removal_note):
            value = metric_fn(x, y)
        assert value == expected

    # mean_squared_log_error
    mean_squared_log_error._warned = False
    with pytest.deprecated_call(match=removal_note):
        log_error = mean_squared_log_error(x, y)
    assert torch.allclose(log_error, torch.tensor(0.0207), atol=1e-4)

    # Second group of class interfaces.
    for metric_cls in (PSNR, R2Score, SSIM):
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_note):
            metric_cls()

    # psnr
    preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
    psnr._warned = False
    with pytest.deprecated_call(match=removal_note):
        psnr_value = psnr(preds, target)
    assert torch.allclose(psnr_value, torch.tensor(2.5527), atol=1e-4)

    # r2score
    target = torch.tensor([3, -0.5, 2, 7])
    preds = torch.tensor([2.5, 0.0, 2, 8])
    r2score._warned = False
    with pytest.deprecated_call(match=removal_note):
        r2_value = r2score(preds, target)
    assert torch.allclose(r2_value, torch.tensor(0.9486), atol=1e-4)

    # ssim on random images versus the same images scaled by 0.75
    preds = torch.rand([16, 1, 16, 16])
    target = preds * 0.75
    ssim._warned = False
    with pytest.deprecated_call(match=removal_note):
        ssim_value = ssim(preds, target)
    assert torch.allclose(ssim_value, torch.tensor(0.9219), atol=1e-4)
def test_v1_5_metric_others():
    """Check the deprecated functional ``bleu_score`` and ``embedding_similarity``.

    Each is called inside ``pytest.deprecated_call`` on fixed inputs and its output is compared with
    a hard-coded expectation (tolerance 1e-4).
    """
    removal_note = 'It will be removed in v1.5.0'

    # bleu_score: one tokenised hypothesis scored against two tokenised references
    hypothesis_tokens = 'the cat is on the mat'.split()
    reference_tokens = ['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]
    translate_corpus = [hypothesis_tokens]
    reference_corpus = [reference_tokens]
    bleu_score._warned = False
    with pytest.deprecated_call(match=removal_note):
        bleu = bleu_score(translate_corpus, reference_corpus)
    assert torch.allclose(bleu, torch.tensor(0.7598), atol=1e-4)

    # embedding_similarity: pairwise similarity matrix of three 4-d embeddings
    embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
    expected_similarity = torch.tensor([
        [0.0000, 1.0000, 0.9759],
        [1.0000, 0.0000, 0.9759],
        [0.9759, 0.9759, 0.0000],
    ])
    embedding_similarity._warned = False
    with pytest.deprecated_call(match=removal_note):
        similarity = embedding_similarity(embeddings)
    assert torch.allclose(similarity, expected_similarity, atol=1e-4)