# lightning/tests/metrics/test_aggregation.py (298 lines, 9.9 KiB, Python)

import pytest
import sys
from collections import namedtuple
from functools import partial
import math
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np
from tests.base import EvalModelTemplate
from pytorch_lightning import Trainer
import tests.base.develop_utils as tutils
from pytorch_lightning.metrics import (
Accuracy,
ConfusionMatrix,
PrecisionRecallCurve,
Precision,
Recall,
AveragePrecision,
AUROC,
FBeta,
F1,
ROC,
MulticlassROC,
MulticlassPrecisionRecallCurve,
DiceCoefficient,
IoU,
MAE,
MSE,
RMSE,
RMSLE,
PSNR,
SSIM,
)
from sklearn.metrics import (
accuracy_score,
confusion_matrix,
precision_recall_curve,
precision_score,
recall_score,
average_precision_score,
roc_auc_score,
fbeta_score,
f1_score,
roc_curve,
jaccard_score,
mean_squared_error,
mean_absolute_error,
mean_squared_log_error
)
from skimage.metrics import (
peak_signal_noise_ratio,
structural_similarity
)
# Structure of a single test case:
#   name             -- human-readable label, used as the pytest parametrize id
#   lightning_metric -- zero-argument callable that builds the Lightning metric
#   comparing_metric -- reference implementation, called with the input pair reversed
#   test_input       -- list of input pairs fed to both implementations
TestCase = namedtuple('TestCase', ['name', 'lightning_metric', 'comparing_metric', 'test_input'])
# The name starts with "Test", so tell pytest this namedtuple is not a test class
# (otherwise collection emits a PytestCollectionWarning about its __new__).
TestCase.__test__ = False

# Set up some standard test cases. A dedicated, seeded generator makes the random
# data reproducible from run to run without touching the global torch RNG.
NB_SAMPLES = 200
_generator = torch.Generator().manual_seed(42)
multiclass_example = [(torch.randint(10, (NB_SAMPLES,), generator=_generator),
                       torch.randint(10, (NB_SAMPLES,), generator=_generator))]
binary_example = [(torch.randint(2, (NB_SAMPLES,), generator=_generator),
                   torch.randint(2, (NB_SAMPLES,), generator=_generator))]
multiclass_and_binary_example = [*multiclass_example, *binary_example]
# NOTE(review): the two fixtures below are single tuples (not lists of pairs like
# the ones above) and are not referenced by any active test case in this file.
binary_example_logits = (torch.randint(2, (NB_SAMPLES,), generator=_generator),
                         torch.randint(5, (NB_SAMPLES,), generator=_generator))
multiclass_example_probs = (torch.randint(10, (NB_SAMPLES,), generator=_generator),
                            torch.randn((NB_SAMPLES, 10), generator=_generator).softmax(-1))
regression_example = [(torch.rand((NB_SAMPLES,), generator=_generator),
                       torch.rand((NB_SAMPLES,), generator=_generator))]
# construct additional test functions
def root_mean_squared_error(x, y):
    """Reference RMSE: square root of sklearn's ``mean_squared_error(x, y)``."""
    mse = mean_squared_error(x, y)
    return math.sqrt(mse)
def root_mean_squared_log_error(x, y):
    """Reference RMSLE: square root of sklearn's ``mean_squared_log_error(x, y)``."""
    msle = mean_squared_log_error(x, y)
    return math.sqrt(msle)
# Define test cases. Each entry pairs a Lightning metric (a class, or a partial
# that pre-binds constructor arguments) with a reference implementation from
# sklearn / skimage. The reference is always called on the input pair reversed.
# TODO: update remaining metrics and uncomment the corresponding test cases.
# NOTE(review): the commented-out 'multiclass roc' and 'multiclass precision
# recall curve' entries use names (multiclass_roc, multiclass_precision_recall_curve)
# that are not imported in this file; they need a reference function before enabling.
TESTS = [
    TestCase(name='accuracy',
             lightning_metric=Accuracy,
             comparing_metric=accuracy_score,
             test_input=multiclass_and_binary_example),
    TestCase(name='confusion matrix without normalize',
             lightning_metric=ConfusionMatrix,
             comparing_metric=confusion_matrix,
             test_input=multiclass_and_binary_example),
    TestCase(name='confusion matrix with normalize',
             lightning_metric=partial(ConfusionMatrix, normalize=True),
             comparing_metric=partial(confusion_matrix, normalize='true'),
             test_input=multiclass_and_binary_example),
    # TestCase(name='precision recall curve',
    #          lightning_metric=PrecisionRecallCurve,
    #          comparing_metric=precision_recall_curve,
    #          test_input=binary_example),
    TestCase(name='precision',
             lightning_metric=Precision,
             comparing_metric=partial(precision_score, average='micro'),
             test_input=multiclass_and_binary_example),
    TestCase(name='recall',
             lightning_metric=Recall,
             comparing_metric=partial(recall_score, average='micro'),
             test_input=multiclass_and_binary_example),
    # TestCase(name='average_precision',
    #          lightning_metric=AveragePrecision,
    #          comparing_metric=average_precision_score,
    #          test_input=binary_example),
    # TestCase(name='auroc',
    #          lightning_metric=AUROC,
    #          comparing_metric=roc_auc_score,
    #          test_input=binary_example),
    TestCase(name='f beta',
             lightning_metric=partial(FBeta, beta=2),
             comparing_metric=partial(fbeta_score, average='micro', beta=2),
             test_input=multiclass_and_binary_example),
    TestCase(name='f1',
             lightning_metric=F1,
             comparing_metric=partial(f1_score, average='micro'),
             test_input=multiclass_and_binary_example),
    # TestCase(name='roc',
    #          lightning_metric=ROC,
    #          comparing_metric=roc_curve,
    #          test_input=binary_example),
    # TestCase(name='multiclass roc',
    #          lightning_metric=MulticlassROC,
    #          comparing_metric=multiclass_roc,
    #          test_input=binary_example),
    # TestCase(name='multiclass precision recall curve',
    #          lightning_metric=MulticlassPrecisionRecallCurve,
    #          comparing_metric=multiclass_precision_recall_curve,
    #          test_input=binary_example),
    # TestCase(name='dice coefficient',
    #          lightning_metric=DiceCoefficient,
    #          comparing_metric=partial(f1_score, average='micro'),
    #          test_input=multiclass_and_binary_example),
    # TestCase(name='intersection over union',
    #          lightning_metric=IoU,
    #          comparing_metric=partial(jaccard_score, average='macro'),
    #          test_input=binary_example),
    TestCase(name='mean squared error',
             lightning_metric=MSE,
             comparing_metric=mean_squared_error,
             test_input=regression_example),
    TestCase(name='root mean squared error',
             lightning_metric=RMSE,
             comparing_metric=root_mean_squared_error,
             test_input=regression_example),
    TestCase(name='mean absolute error',
             lightning_metric=MAE,
             comparing_metric=mean_absolute_error,
             test_input=regression_example),
    TestCase(name='root mean squared log error',
             lightning_metric=RMSLE,
             comparing_metric=root_mean_squared_log_error,
             test_input=regression_example),
    TestCase(name='peak signal-to-noise ratio',
             lightning_metric=partial(PSNR, data_range=10),
             comparing_metric=partial(peak_signal_noise_ratio, data_range=10),
             test_input=regression_example),
    # TestCase(name='structural similarity index measure',
    #          lightning_metric=SSIM,
    #          comparing_metric=structural_similarity,
    #          test_input=regression_example),
]
# Utility test functions
def _idsfn(test):
""" Return id for current example being tested """
return test.name
def _setup_ddp(rank, worldsize):
    """Join a gloo process group as ``rank`` out of ``worldsize`` processes.

    The rendezvous address is fixed to localhost. MASTER_PORT is not set here;
    it is expected to already be in the environment (the spawning tests call
    ``tutils.set_random_master_port()`` before ``mp.spawn``).
    """
    import os

    os.environ['MASTER_ADDR'] = 'localhost'
    dist.init_process_group(backend="gloo", rank=rank, world_size=worldsize)
def comparing_fn(lightning_val, comparing_val, rtol=1e-03, atol=1e-08):
    """Assert that a Lightning metric output matches the reference output.

    Handles both single-output and multi-output metrics. A ``tuple`` reference
    value marks a multi-output metric; the outputs are then compared pairwise.

    Args:
        lightning_val: tensor, or a sequence of tensors for multi-output metrics.
        comparing_val: reference value (array-like or scalar), or a tuple of them.
        rtol: relative tolerance forwarded to ``np.allclose``.
        atol: absolute tolerance forwarded to ``np.allclose``.

    Raises:
        AssertionError: if the number of outputs differs or any value is not close.
    """
    def _to_numpy(tensor):
        # detach()/cpu() so tensors that require grad or live on a GPU
        # can still be converted; .numpy() alone would raise for those.
        return tensor.detach().cpu().numpy()

    if isinstance(comparing_val, tuple):  # multi output
        # zip() silently stops at the shorter sequence, so a metric returning
        # fewer outputs than the reference would otherwise pass unnoticed.
        assert len(lightning_val) == len(comparing_val), (
            'number of outputs differs: lightning={} reference={}'.format(
                len(lightning_val), len(comparing_val)))
        for idx, (l_score, c_score) in enumerate(zip(lightning_val, comparing_val)):
            assert np.allclose(_to_numpy(l_score), c_score, rtol, atol), (
                'output {} does not match the reference'.format(idx))
    else:  # single output
        assert np.allclose(_to_numpy(lightning_val), comparing_val, rtol, atol), (
            'output does not match the reference')
# ===== Tests start here =====
def _test_ddp_single_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs):
    """DDP worker: give each process an equal, interleaved share of every test input.

    Each rank calls the metric once on its share. The result is compared with the
    reference metric computed on the full, un-sharded input, so the test expects
    the metric's forward output to reflect the data from all ranks.

    Args:
        rank: index of this process (supplied first by ``mp.spawn``).
        worldsize: total number of processes in the group.
        lightning_metric: zero-argument callable returning the metric under test.
        comparing_metric: reference implementation, called with the input pair reversed.
        test_inputs: list of input tuples; each element is a sequence of tensors.
    """
    _setup_ddp(rank, worldsize)

    # Build the metric inside the worker process.
    metric = lightning_metric()

    for test_input in test_inputs:
        # Rank r receives samples r, r + worldsize, r + 2 * worldsize, ...
        # (worldsize=2: rank 0 -> 0, 2, 4, ...; rank 1 -> 1, 3, 5, ...).
        # Striding by worldsize (not a hard-coded 2) keeps the shards disjoint
        # and complete for any number of processes.
        lightning_val = metric(*[ti[rank::worldsize] for ti in test_input])
        comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
        comparing_fn(lightning_val, comparing_val)
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_ddp(test):
    """Check that a metric synced over 2 DDP processes matches the full-data reference."""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    spawn_args = (worldsize, test.lightning_metric, test.comparing_metric, test.test_input)
    mp.spawn(_test_ddp_single_batch, args=spawn_args, nprocs=worldsize)
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_multi_batch(test):
    """Check that aggregating over several batches matches the full-data reference."""
    metric = test.lightning_metric()
    reference = test.comparing_metric
    nb_batches = 2

    for sample in test.test_input:
        # Split the data into 2 artificial, interleaved batches:
        #   batch 0 -> samples 0, 2, 4, ...; batch 1 -> samples 1, 3, 5, ...
        for batch_idx in range(nb_batches):
            metric(*[ti[batch_idx::nb_batches] for ti in sample])
        lightning_val = metric.aggregated
        comparing_val = reference(*[ti.numpy() for ti in reversed(sample)])
        comparing_fn(lightning_val, comparing_val)
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_multi_batch_unequal_sizes(test):
    """Check that aggregating over two batches of different sizes matches the reference."""
    metric = test.lightning_metric()
    reference = test.comparing_metric

    for sample in test.test_input:
        # First batch: the leading 3/4 of the samples; second batch: the rest.
        metric(*[ti[:int(3 / 4 * len(ti))] for ti in sample])
        metric(*[ti[int(3 / 4 * len(ti)):] for ti in sample])
        lightning_val = metric.aggregated
        comparing_val = reference(*[ti.numpy() for ti in reversed(sample)])
        comparing_fn(lightning_val, comparing_val)
def _test_ddp_multi_batch(rank, worldsize, lightning_metric, comparing_metric,
                          test_inputs, nb_batches=2):
    """DDP worker: check aggregation across multiple processes and multiple batches.

    Every test input is split into ``worldsize * nb_batches`` interleaved shards.
    This rank feeds its ``nb_batches`` shards to the metric one batch at a time.
    It then compares ``metric.aggregated``, which the test expects to reflect
    every rank's batches, with the reference computed on the full input.

    Args:
        rank: index of this process (supplied first by ``mp.spawn``).
        worldsize: total number of processes in the group.
        lightning_metric: zero-argument callable returning the metric under test.
        comparing_metric: reference implementation, called with the input pair reversed.
        test_inputs: list of input tuples; each element is a sequence of tensors.
        nb_batches: number of batches each rank splits its share into.
    """
    _setup_ddp(rank, worldsize)

    metric = lightning_metric()
    # Interleave with this stride so that all ranks and batches together
    # cover every sample exactly once, for any worldsize / nb_batches.
    stride = worldsize * nb_batches

    for test_input in test_inputs:
        for batch_idx in range(nb_batches):
            # With worldsize=2 and nb_batches=2 (stride 4):
            #   rank 0, batch 0 -> samples 0, 4, 8, ...
            #   rank 0, batch 1 -> samples 1, 5, 9, ...
            #   rank 1, batch 0 -> samples 2, 6, 10, ...
            #   rank 1, batch 1 -> samples 3, 7, 11, ...
            offset = rank * nb_batches + batch_idx
            metric(*[ti[offset::stride] for ti in test_input])
        lightning_val = metric.aggregated
        comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
        comparing_fn(lightning_val, comparing_val)
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_ddp_multi_batch(test):
    """Check aggregation over 2 DDP processes, each feeding several batches."""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    spawn_args = (worldsize, test.lightning_metric, test.comparing_metric, test.test_input)
    mp.spawn(_test_ddp_multi_batch, args=spawn_args, nprocs=worldsize)