Fix: gather_all_tensors across GPUs in DDP (#3319)

* Fix: gather_all_tensors across GPUs in metrics

* add a test case for gather_all_tensors_ddp in #3253
This commit is contained in:
HT Liu 2020-09-03 18:27:32 +08:00 committed by GitHub
parent ee72271d20
commit d521c1b178
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 1 deletions

View File

@ -301,7 +301,7 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor],
world_size = torch.distributed.get_world_size(group)
gathered_result = world_size * [torch.zeros_like(result)]
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
# sync and broadcast all
torch.distributed.barrier(group=group)

View File

@ -15,6 +15,7 @@ from pytorch_lightning.metrics.converters import (
_numpy_metric_conversion,
_tensor_metric_conversion,
sync_ddp_if_available,
gather_all_tensors_if_available,
tensor_metric,
numpy_metric
)
@ -134,6 +135,17 @@ def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
'Sync-Reduce does not work properly with DDP and Tensors'
def _ddp_test_gather_all_tensors(rank, worldsize):
    """Check on one DDP worker that gathering returns every rank's tensor in rank order.

    Args:
        rank: index of this worker process; also the value stored in the tensor it contributes.
        worldsize: total number of worker processes taking part in the gather.
    """
    _setup_ddp(rank, worldsize)

    tensor = torch.tensor([rank])
    gathered_tensors = gather_all_tensors_if_available(tensor)
    expected_tensors = [torch.tensor([i]) for i in range(worldsize)]

    # zip() stops at the shorter input, so without this check a result
    # holding too few entries would silently pass the comparison below
    assert len(gathered_tensors) == len(expected_tensors)
    for gathered, expected in zip(gathered_tensors, expected_tensors):
        assert gathered.equal(expected)
@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
def test_sync_reduce_ddp():
"""Make sure sync-reduce works with DDP"""
@ -164,6 +176,16 @@ def test_sync_reduce_simple():
'Sync-Reduce does not work properly without DDP and Tensors'
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_gather_all_tensors_ddp():
    """Spawn two DDP worker processes and run the gather-all-tensors check in each."""
    tutils.reset_seed()
    tutils.set_random_master_port()

    num_processes = 2
    # each spawned worker is handed the worker count as its extra argument
    mp.spawn(_ddp_test_gather_all_tensors, nprocs=num_processes, args=(num_processes,))
def _test_tensor_metric(is_ddp: bool):
@tensor_metric()
def tensor_test_metric(*args, **kwargs):