Fix: gather_all_tensors across GPUs in DDP (#3319)

* Fix: gather_all_tensors across GPUs in metrics
* Add a test case for gather_all_tensors_ddp in #3253
parent ee72271d20, commit d521c1b178
@@ -301,7 +301,7 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor],
     world_size = torch.distributed.get_world_size(group)
 
-    gathered_result = world_size * [torch.zeros_like(result)]
+    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
 
     # sync and broadcast all
     torch.distributed.barrier(group=group)
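The one-line change above fixes a classic Python aliasing pitfall: world_size * [tensor] builds a list of world_size references to the same tensor object, so every slot of the gather buffer shares one storage and the gathered results can come back identical across ranks. A minimal standalone sketch (not part of the commit) demonstrating the difference:

    import torch

    result = torch.zeros(3)

    # Old pattern: one tensor object repeated world_size (= 2) times.
    aliased = 2 * [torch.zeros_like(result)]
    assert aliased[0] is aliased[1]          # both slots share one storage

    # Fixed pattern: an independent receive buffer per rank.
    distinct = [torch.zeros_like(result) for _ in range(2)]
    assert distinct[0] is not distinct[1]

    # Writing through one alias mutates every slot of the old buffer, which
    # is how all ranks' gather results could end up identical.
    aliased[0].fill_(1.0)
    print(aliased[1])   # tensor([1., 1., 1.])
    print(distinct[1])  # tensor([0., 0., 0.])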
@@ -15,6 +15,7 @@ from pytorch_lightning.metrics.converters import (
     _numpy_metric_conversion,
     _tensor_metric_conversion,
     sync_ddp_if_available,
+    gather_all_tensors_if_available,
     tensor_metric,
     numpy_metric
 )
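For context, the first hunk touches only one line of gather_all_tensors_if_available. Below is a hedged sketch of the surrounding logic, assuming the usual torch.distributed availability guards; the actual body in the repo may differ in details:

    import torch
    from typing import Any, Optional

    def gather_all_tensors_sketch(result: torch.Tensor,
                                  group: Optional[Any] = None):
        # Sketch only: gather `result` from every rank when a distributed
        # process group is running, otherwise return it unchanged.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size(group)

            # Post-fix: one independent zero buffer per rank.
            gathered_result = [torch.zeros_like(result) for _ in range(world_size)]

            # sync and broadcast all
            torch.distributed.barrier(group=group)
            torch.distributed.all_gather(gathered_result, result, group=group)
            return gathered_result
        return result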
@@ -134,6 +135,17 @@ def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
         'Sync-Reduce does not work properly with DDP and Tensors'
 
 
+def _ddp_test_gather_all_tensors(rank, worldsize):
+    _setup_ddp(rank, worldsize)
+
+    tensor = torch.tensor([rank])
+    gather_tensors = gather_all_tensors_if_available(tensor)
+    manual_tensors = [torch.tensor([i]) for i in range(worldsize)]
+
+    for t1, t2 in zip(gather_tensors, manual_tensors):
+        assert t1.equal(t2)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
 def test_sync_reduce_ddp():
     """Make sure sync-reduce works with DDP"""
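The new worker relies on _setup_ddp, a helper defined earlier in the test module but outside this diff. A hypothetical stand-in for readers running the snippet in isolation (the backend choice and environment variables are assumptions, not taken from the repo):

    import os
    import torch

    def _setup_ddp(rank, worldsize):
        # Hypothetical stand-in for the test suite's helper; the real one
        # may differ. MASTER_PORT is assumed to be exported beforehand,
        # e.g. by tutils.set_random_master_port() in the test below.
        os.environ['MASTER_ADDR'] = 'localhost'
        torch.distributed.init_process_group('gloo', rank=rank,
                                             world_size=worldsize)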
@@ -164,6 +176,16 @@ def test_sync_reduce_simple():
         'Sync-Reduce does not work properly without DDP and Tensors'
 
 
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_gather_all_tensors_ddp():
+    """Make sure gather_all_tensors works with DDP"""
+    tutils.reset_seed()
+    tutils.set_random_master_port()
+
+    worldsize = 2
+    mp.spawn(_ddp_test_gather_all_tensors, args=(worldsize,), nprocs=worldsize)
+
+
 def _test_tensor_metric(is_ddp: bool):
     @tensor_metric()
     def tensor_test_metric(*args, **kwargs):
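Putting the pieces together: a self-contained repro of what the new test exercises, runnable on CPU via the gloo backend (the port, backend, and world size here are arbitrary choices for the sketch, not values from the commit):

    import os
    import torch
    import torch.multiprocessing as mp

    def _worker(rank, worldsize):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29501'
        torch.distributed.init_process_group('gloo', rank=rank,
                                             world_size=worldsize)

        tensor = torch.tensor([rank])
        gathered = [torch.zeros_like(tensor) for _ in range(worldsize)]
        torch.distributed.all_gather(gathered, tensor)

        # Each slot should hold the tensor contributed by the matching rank;
        # before the fix, aliased buffers could make every slot identical.
        for i, gathered_tensor in enumerate(gathered):
            assert gathered_tensor.equal(torch.tensor([i]))

    if __name__ == "__main__":
        mp.spawn(_worker, args=(2,), nprocs=2)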