import pytest
import sys
from collections import namedtuple
from functools import partial
import math

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np

from tests.base import EvalModelTemplate
from pytorch_lightning import Trainer
import tests.base.develop_utils as tutils
from pytorch_lightning.metrics import (
    Accuracy,
    ConfusionMatrix,
    PrecisionRecallCurve,
    Precision,
    Recall,
    AveragePrecision,
    AUROC,
    FBeta,
    F1,
    ROC,
    MulticlassROC,
    MulticlassPrecisionRecallCurve,
    DiceCoefficient,
    IoU,
    MAE,
    MSE,
    RMSE,
    RMSLE,
    PSNR,
    SSIM,
)

from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_recall_curve,
    precision_score,
    recall_score,
    average_precision_score,
    roc_auc_score,
    fbeta_score,
    f1_score,
    roc_curve,
    jaccard_score,
    mean_squared_error,
    mean_absolute_error,
    mean_squared_log_error
)

from skimage.metrics import (
    peak_signal_noise_ratio,
    structural_similarity
)


# example structure
TestCase = namedtuple('example', ['name', 'lightning_metric', 'comparing_metric', 'test_input'])

# set up some standard test cases
NB_SAMPLES = 200
multiclass_example = [(torch.randint(10, (NB_SAMPLES,)), torch.randint(10, (NB_SAMPLES,)))]
binary_example = [(torch.randint(2, (NB_SAMPLES,)), torch.randint(2, (NB_SAMPLES,)))]
multiclass_and_binary_example = [*multiclass_example, *binary_example]
binary_example_logits = (torch.randint(2, (NB_SAMPLES,)), torch.randint(5, (NB_SAMPLES,)))
multiclass_example_probs = (torch.randint(10, (NB_SAMPLES,)), torch.randn((NB_SAMPLES, 10)).softmax(-1))
regression_example = [(torch.rand((NB_SAMPLES,)), torch.rand((NB_SAMPLES,)))]
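
# NOTE: each example above is a (prediction, target) pair. The tests below pass
# the pair to the Lightning metric in that order, and to the sklearn/skimage
# reference in reversed order, i.e. (y_true, y_pred) for the sklearn functions.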


# construct additional test functions
def root_mean_squared_error(x, y):
    return math.sqrt(mean_squared_error(x, y))


def root_mean_squared_log_error(x, y):
    return math.sqrt(mean_squared_log_error(x, y))


# Define testcases
# TODO: update remaining metrics and uncomment the corresponding test cases
TESTS = [
    TestCase('accuracy',
             Accuracy,
             accuracy_score,
             multiclass_and_binary_example),
    TestCase('confusion matrix without normalize',
             ConfusionMatrix,
             confusion_matrix,
             multiclass_and_binary_example),
    TestCase('confusion matrix with normalize',
             partial(ConfusionMatrix, normalize=True),
             partial(confusion_matrix, normalize='true'),
             multiclass_and_binary_example),
    # TestCase('precision recall curve',
    #          PrecisionRecallCurve,
    #          precision_recall_curve,
    #          binary_example),
    TestCase('precision',
             Precision,
             partial(precision_score, average='micro'),
             multiclass_and_binary_example),
    TestCase('recall',
             Recall,
             partial(recall_score, average='micro'),
             multiclass_and_binary_example),
    # TestCase('average_precision',
    #          AveragePrecision,
    #          average_precision_score,
    #          binary_example),
    # TestCase('auroc',
    #          AUROC,
    #          roc_auc_score,
    #          binary_example),
    TestCase('f beta',
             partial(FBeta, beta=2),
             partial(fbeta_score, average='micro', beta=2),
             multiclass_and_binary_example),
    TestCase('f1',
             F1,
             partial(f1_score, average='micro'),
             multiclass_and_binary_example),
    # TestCase('roc',
    #          ROC,
    #          roc_curve,
    #          binary_example),
    # TestCase('multiclass roc',
    #          MulticlassROC,
    #          multiclass_roc,
    #          binary_example),
    # TestCase('multiclass precision recall curve',
    #          MulticlassPrecisionRecallCurve,
    #          multiclass_precision_recall_curve,
    #          binary_example),
    # TestCase('dice coefficient',
    #          DiceCoefficient,
    #          partial(f1_score, average='micro'),
    #          multiclass_and_binary_example),
    # TestCase('intersection over union',
    #          IoU,
    #          partial(jaccard_score, average='macro'),
    #          binary_example),
    TestCase('mean squared error',
             MSE,
             mean_squared_error,
             regression_example),
    TestCase('root mean squared error',
             RMSE,
             root_mean_squared_error,
             regression_example),
    TestCase('mean absolute error',
             MAE,
             mean_absolute_error,
             regression_example),
    TestCase('root mean squared log error',
             RMSLE,
             root_mean_squared_log_error,
             regression_example),
    TestCase('peak signal-to-noise ratio',
             partial(PSNR, data_range=10),
             partial(peak_signal_noise_ratio, data_range=10),
             regression_example),
    # TestCase('structural similarity index measure',
    #          SSIM,
    #          structural_similarity,
    #          regression_example)
]
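
# Illustration only (not executed by the test suite): how a single entry in TESTS
# maps onto the checks performed below, assuming the (prediction, target) layout
# of the examples above.
#
#   case = TESTS[0]                                   # the 'accuracy' case
#   pred, target = case.test_input[0]
#   lightning_val = case.lightning_metric()(pred, target)
#   reference_val = case.comparing_metric(target.numpy(), pred.numpy())
#   comparing_fn(lightning_val, reference_val)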


# Utility test functions
def _idsfn(test):
    """ Return the id for the current example being tested """
    return test.name


def _setup_ddp(rank, worldsize):
    """ Set up the DDP environment for testing """
    import os
    os.environ['MASTER_ADDR'] = 'localhost'
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)
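
# NOTE: only MASTER_ADDR is set in `_setup_ddp`; the MASTER_PORT environment
# variable is assumed to be set by `tutils.set_random_master_port()` in the
# parent process before `mp.spawn` (see the DDP tests below).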


def comparing_fn(lightning_val, comparing_val, rtol=1e-03, atol=1e-08):
    """ Compare a Lightning metric output against the reference output,
    for both multi- and single-output metrics """
    # multi output
    if isinstance(comparing_val, tuple):
        for l_score, c_score in zip(lightning_val, comparing_val):
            assert np.allclose(l_score.numpy(), c_score, rtol, atol)
    else:  # single output
        assert np.allclose(lightning_val.numpy(), comparing_val, rtol, atol)


# ===== Tests start here =====
def _test_ddp_single_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs):
    """ DDP testing function; divides test_inputs equally between all processes """
    _setup_ddp(rank, worldsize)

    # Setup metric for ddp
    lightning_metric = lightning_metric()
    for test_input in test_inputs:
        # rank 0 receives samples 0,2,4,...
        # rank 1 receives samples 1,3,5,...
        lightning_val = lightning_metric(*[ti[rank::2] for ti in test_input])

        comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])

        comparing_fn(lightning_val, comparing_val)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_ddp(test):
    """ Make sure that metrics are correctly synced and reduced in DDP mode """
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_test_ddp_single_batch,
             args=(worldsize,
                   test.lightning_metric,
                   test.comparing_metric,
                   test.test_input),
             nprocs=worldsize)


@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_multi_batch(test):
    """ Test that aggregation works over multiple batches """
    lightning_metric = test.lightning_metric()
    comparing_metric = test.comparing_metric

    for test_input in test.test_input:
        for i in range(2):  # split the data into 2 artificial batches
            # first batch consists of samples 0,2,4,...
            # second batch consists of samples 1,3,5,...
            _ = lightning_metric(*[ti[i::2] for ti in test_input])
        lightning_val = lightning_metric.aggregated
        comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])

        comparing_fn(lightning_val, comparing_val)


@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_multi_batch_unequal_sizes(test):
    """ Test that aggregation works over multiple batches with uneven sizes """
    lightning_metric = test.lightning_metric()
    comparing_metric = test.comparing_metric

    for test_input in test.test_input:
        for i in range(2):  # split the data into 2 artificial batches of unequal size
            if i == 0:  # allocate 3/4 of the data to the first batch
                _ = lightning_metric(*[ti[:int(3 / 4 * len(ti))] for ti in test_input])
            else:
                _ = lightning_metric(*[ti[int(3 / 4 * len(ti)):] for ti in test_input])
        lightning_val = lightning_metric.aggregated
        comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])

        comparing_fn(lightning_val, comparing_val)


def _test_ddp_multi_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs):
    """ DDP testing function; tests that the metric works with aggregation over
    multiple devices and multiple batches """
    _setup_ddp(rank, worldsize)

    # Setup metric for ddp
    lightning_metric = lightning_metric()
    for test_input in test_inputs:
        for i in range(2):  # artificially divide samples between batches and processes
            # rank 0, batch 0 consists of samples 0,4,8,...
            # rank 0, batch 1 consists of samples 1,5,9,...
            # rank 1, batch 0 consists of samples 2,6,10,...
            # rank 1, batch 1 consists of samples 3,7,11,...
            _ = lightning_metric(*[ti[i + worldsize * rank::4] for ti in test_input])
        lightning_val = lightning_metric.aggregated
        comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])

        comparing_fn(lightning_val, comparing_val)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_ddp_multi_batch(test):
    """ Test that aggregation works in DDP mode with multiple batches """
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_test_ddp_multi_batch,
             args=(worldsize,
                   test.lightning_metric,
                   test.comparing_metric,
                   test.test_input),
             nprocs=worldsize)