# lightning/tests/metrics/utils.py
#
# 67 lines
# 2.3 KiB
# Python

import torch
import numpy as np
import os
import sys
import pytest
import pickle
# Shared settings for the metric test harness:
#   NUM_PROCESSES -- number of worker processes spawned by ``compute_batch``
#                    when ``ddp=True``.
#   NUM_BATCHES   -- number of batches iterated over (and stacked for the
#                    final comparison) in ``_compute_batch``.
#   BATCH_SIZE    -- not referenced in this module; presumably the per-batch
#                    sample count used when tests build their inputs
#                    (verify at the call sites).
NUM_PROCESSES, NUM_BATCHES, BATCH_SIZE = 2, 10, 16
def setup_ddp(rank, world_size, *, master_addr='localhost', master_port=8088):
    """Initialise a ``gloo`` process group for one DDP test worker.

    Args:
        rank: index of the calling process within the group.
        world_size: total number of processes in the group.
        master_addr: host of the rendezvous process. Defaults to
            ``'localhost'``, the previously hard-coded value.
        master_port: rendezvous port. Defaults to ``8088``, the previously
            hard-coded value. It is exposed so that callers can choose a
            free port when several DDP tests may run at the same time on
            one machine.
    """
    # The default ``env://`` init method of ``init_process_group`` reads the
    # rendezvous location from these environment variables.
    os.environ['MASTER_ADDR'] = str(master_addr)
    os.environ['MASTER_PORT'] = str(master_port)
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_step, worldsize=1, metric_args=None):
    """Run a metric over the test batches and compare it with a reference.

    This runs either directly (``worldsize == 1``) or once per worker process
    under ``torch.multiprocessing.spawn``, which supplies ``rank``.

    Args:
        rank: index of the current process; ``0`` when not running DDP.
        preds: indexable collection of prediction batches, indexed by
            ``0 .. NUM_BATCHES - 1``.
        target: indexable collection of target batches aligned with ``preds``.
        metric_class: metric class under test. It is constructed with
            ``compute_on_step=True``, the given ``ddp_sync_on_step`` and
            ``**metric_args``.
        sk_metric: reference callable ``sk_metric(preds, target)``. Its result
            is compared to the metric's result with ``np.allclose``.
        ddp_sync_on_step: forwarded to the ``metric_class`` constructor.
        worldsize: number of participating processes. A process group is set
            up only when this is greater than 1.
        metric_args: extra keyword arguments for ``metric_class``. ``None``
            (the default) means no extra arguments.

    Raises:
        AssertionError: if a per-batch or final result differs from
            ``sk_metric``, or if ``compute()`` does not return a tensor.
    """
    # A ``None`` sentinel avoids a dict default shared across calls.
    if metric_args is None:
        metric_args = {}

    metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)

    # verify metrics work after being loaded from pickled state
    pickled_metric = pickle.dumps(metric)
    metric = pickle.loads(pickled_metric)

    # Only set up distributed communication when more than one process is used.
    if worldsize > 1:
        setup_ddp(rank, worldsize)

    # Each rank handles every ``worldsize``-th batch, starting at its own rank,
    # so in a given step the ranks process batches i, i + 1, ..., i + worldsize - 1.
    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = metric(preds[i], target[i])

        if metric.ddp_sync_on_step:
            # With per-step syncing, the test expects the batch result to match
            # the reference over all batches seen by all ranks in this step.
            # Only rank 0 performs that check.
            if rank == 0:
                ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)])
                ddp_target = torch.stack([target[i + r] for r in range(worldsize)])
                sk_batch_result = sk_metric(ddp_preds, ddp_target)
                assert np.allclose(batch_result.numpy(), sk_batch_result)
        else:
            sk_batch_result = sk_metric(preds[i], target[i])
            assert np.allclose(batch_result.numpy(), sk_batch_result)

    # check on all batches on all ranks
    result = metric.compute()
    assert isinstance(result, torch.Tensor)

    total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)])
    total_target = torch.stack([target[i] for i in range(NUM_BATCHES)])
    sk_result = sk_metric(total_preds, total_target)
    assert np.allclose(result.numpy(), sk_result)
def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args=None):
    """Check ``metric_class`` against ``sk_metric``, optionally under DDP.

    Args:
        preds: prediction batches, forwarded to ``_compute_batch``.
        target: target batches, forwarded to ``_compute_batch``.
        metric_class: metric class under test, forwarded to ``_compute_batch``.
        sk_metric: reference implementation, forwarded to ``_compute_batch``.
        ddp_sync_on_step: forwarded to ``_compute_batch``.
        ddp: if ``True``, spawn ``NUM_PROCESSES`` worker processes that each
            run ``_compute_batch`` with its own rank. Otherwise run it once in
            this process as rank 0 with a world size of 1.
        metric_args: extra keyword arguments for ``metric_class``. ``None``
            (the default) means no extra arguments.

    When ``ddp`` is requested on Windows, this calls ``pytest.skip`` and so
    skips the calling test.
    """
    # Normalise here so the callee always receives a dict and no default
    # dict is shared across calls.
    if metric_args is None:
        metric_args = {}

    if ddp:
        if sys.platform == "win32":
            pytest.skip("DDP not supported on windows")
        # spawn calls fn(i, *args); the process index i becomes ``rank``.
        torch.multiprocessing.spawn(
            _compute_batch,
            args=(preds, target, metric_class, sk_metric, ddp_sync_on_step, NUM_PROCESSES, metric_args),
            nprocs=NUM_PROCESSES,
        )
    else:
        # Positional args: rank=0, ..., worldsize=1, metric_args.
        _compute_batch(0, preds, target, metric_class, sk_metric, ddp_sync_on_step, 1, metric_args)