67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
import torch
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
import pytest
|
|
import pickle
|
|
|
|
# Number of worker processes spawned by compute_batch for the DDP run; the
# same value is passed to each worker as its world size.
NUM_PROCESSES = 2
|
|
# Number of batches _compute_batch walks through; preds/target are indexed
# with batch numbers 0 .. NUM_BATCHES - 1.
NUM_BATCHES = 10
|
|
# Not referenced by the helpers visible in this file.
# NOTE(review): presumably the per-batch sample count used by the test
# modules that build preds/target -- confirm against callers.
BATCH_SIZE = 16
|
|
|
|
|
|
def setup_ddp(rank, world_size):
    """Export the rendezvous settings and join the ``gloo`` process group.

    Sets the ``MASTER_ADDR`` and ``MASTER_PORT`` environment variables of the
    current process and then calls ``torch.distributed.init_process_group``.

    Args:
        rank: rank of the calling process within the group.
        world_size: total number of processes in the group.
    """
    rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": "8088"}
    for variable, value in rendezvous.items():
        os.environ[variable] = value

    torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size)
|
|
|
|
|
|
def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_step, worldsize=1, metric_args=None):
    """Check a metric against a reference callable, per batch and in aggregate.

    Args:
        rank: index of the current process; also the first batch index this
            process handles (it then advances in steps of ``worldsize``).
        preds: prediction tensors, indexable by batch number.
        target: target tensors, indexable by batch number.
        metric_class: metric class under test; instantiated with
            ``compute_on_step=True``, ``ddp_sync_on_step`` and ``metric_args``.
        sk_metric: reference callable ``sk_metric(preds, target)`` whose output
            the metric has to match.
        ddp_sync_on_step: forwarded to the metric constructor.
        worldsize: number of participating processes; the process group is
            only initialised when this is greater than 1.
        metric_args: optional extra keyword arguments for the metric
            constructor. ``None`` (the default) means no extra arguments.

    Raises:
        AssertionError: if a per-batch result or the final aggregated result
            differs from the reference, or the aggregate is not a tensor.
    """
    # A ``None`` sentinel instead of a ``{}`` default: a dict literal in the
    # signature is a single object shared by every call of the function.
    metric_args = {} if metric_args is None else metric_args

    metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)

    # verify metrics work after being loaded from pickled state
    metric = pickle.loads(pickle.dumps(metric))

    # Only use ddp if world size is larger than one
    if worldsize > 1:
        setup_ddp(rank, worldsize)

    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = metric(preds[i], target[i])

        if not metric.ddp_sync_on_step:
            # No per-step sync: the step result only covers this batch.
            sk_batch_result = sk_metric(preds[i], target[i])
        elif rank == 0:
            # With per-step sync only rank 0 is checked, against the reference
            # over batches i .. i + worldsize - 1, i.e. the batches the ranks
            # process during this same step.
            # NOTE(review): i + r stays in range only while NUM_BATCHES is a
            # multiple of worldsize -- confirm if either constant changes.
            ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)])
            ddp_target = torch.stack([target[i + r] for r in range(worldsize)])
            sk_batch_result = sk_metric(ddp_preds, ddp_target)
        else:
            # Other ranks skip the per-step comparison when syncing.
            continue

        assert np.allclose(batch_result.numpy(), sk_batch_result)

    # check on all batches on all ranks
    result = metric.compute()
    assert isinstance(result, torch.Tensor)

    total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)])
    total_target = torch.stack([target[i] for i in range(NUM_BATCHES)])
    sk_result = sk_metric(total_preds, total_target)

    assert np.allclose(result.numpy(), sk_result)
|
|
|
|
|
|
def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args=None):
    """Run the metric-vs-reference comparison in-process or across spawned workers.

    Args:
        preds: prediction tensors, indexable by batch number.
        target: target tensors, indexable by batch number.
        metric_class: metric class under test.
        sk_metric: reference callable the metric is compared against.
        ddp_sync_on_step: forwarded to the metric constructor by the worker.
        ddp: when true, run the comparison in ``NUM_PROCESSES`` processes
            started with ``torch.multiprocessing.spawn``; otherwise run it
            once in the current process as rank 0 with world size 1.
        metric_args: optional extra keyword arguments for the metric
            constructor. ``None`` (the default) means no extra arguments.
    """
    # A ``None`` sentinel instead of a ``{}`` default: a dict literal in the
    # signature is a single object shared by every call. Normalise here so the
    # worker always receives a real dict it can unpack with ``**``.
    metric_args = {} if metric_args is None else metric_args

    if not ddp:
        # first args: rank, last args: world size
        _compute_batch(0, preds, target, metric_class, sk_metric, ddp_sync_on_step, 1, metric_args)
        return

    if sys.platform == "win32":
        pytest.skip("DDP not supported on windows")

    # spawn() calls the worker with the process index prepended to ``args``,
    # which fills the worker's leading ``rank`` parameter.
    torch.multiprocessing.spawn(
        _compute_batch,
        args=(preds, target, metric_class, sk_metric, ddp_sync_on_step, NUM_PROCESSES, metric_args),
        nprocs=NUM_PROCESSES,
    )
|