38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
|
|
import tests.base.develop_utils as tutils
|
|
import sys
|
|
|
|
|
|
def _setup_ddp(rank, worldsize):
    """Join a gloo process group for one spawned worker.

    Args:
        rank: index of this worker within the group.
        worldsize: total number of workers in the group.
    """
    import os

    # All workers rendezvous on this machine. MASTER_PORT is not set here;
    # presumably the parent exported it before spawning — confirm.
    os.environ.update(MASTER_ADDR="localhost")

    dist.init_process_group(backend="gloo", rank=rank, world_size=worldsize)
|
|
|
|
|
|
def _ddp_test_fn(rank, worldsize, result_cls: type):
    """Worker body: log a tensor with DDP sync and check it was sum-reduced.

    Every worker logs ``torch.tensor([1.0])`` with ``sync_ddp=True`` and a
    SUM reduce op, so the synced value must equal the world size.

    Args:
        rank: index of this worker (supplied by ``mp.spawn``).
        worldsize: total number of workers.
        result_cls: the ``Result`` class (or a subclass of it) to
            instantiate; note this is a class, not an instance.
    """
    _setup_ddp(rank, worldsize)
    try:
        tensor = torch.tensor([1.0])

        res = result_cls()
        res.log("test_tensor", tensor, sync_ddp=True, sync_ddp_op=dist.ReduceOp.SUM)

        assert res["test_tensor"].item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
    finally:
        # Always release the process group before the worker exits, even
        # when the assertion above fails.
        dist.destroy_process_group()
|
|
|
|
|
|
@pytest.mark.parametrize("result_cls", [Result, TrainResult, EvalResult])
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_result_reduce_ddp(result_cls):
    """Make sure result logging works with DDP.

    Spawns two worker processes, each running ``_ddp_test_fn`` with the
    parametrized result class.
    """
    # Reseed and pick a master port in the parent before spawning; the
    # workers presumably pick the port up from the environment — confirm.
    tutils.reset_seed()
    tutils.set_random_master_port()

    n_workers = 2
    spawn_args = (n_workers, result_cls)
    mp.spawn(_ddp_test_fn, args=spawn_args, nprocs=n_workers)
|