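"""Tests that logged ``Result`` tensors are reduced across processes under DDP."""
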
import sys

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult

import tests.base.develop_utils as tutils


def _setup_ddp(rank, worldsize):
    import os

    os.environ["MASTER_ADDR"] = "localhost"
    # MASTER_PORT is expected to be set already by the parent process
    # (see tutils.set_random_master_port() in the test below)

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _ddp_test_fn(rank, worldsize, result_cls: Result):
    _setup_ddp(rank, worldsize)
    tensor = torch.tensor([1.0])

    res = result_cls()
    res.log("test_tensor", tensor, sync_ddp=True, sync_ddp_op=torch.distributed.ReduceOp.SUM)

    # every rank logs 1.0, so a SUM reduction across all processes must equal the world size
    assert res["test_tensor"].item() == dist.get_world_size(), (
        "Result-Log does not work properly with DDP and Tensors"
    )


@pytest.mark.parametrize("result_cls", [Result, TrainResult, EvalResult])
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on Windows")
def test_result_reduce_ddp(result_cls):
    """Make sure result logging works with DDP."""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize, result_cls), nprocs=worldsize)
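

# Example invocation (the -k expression selects the test above by name):
#   pytest -k test_result_reduce_ddp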