Auto convert to contiguous format for all_gather (#4907)

* convert memory format

* changelog

* formatting

* suggestions

* retrigger tests

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: chaton <thomas@grid.ai>
Nicki Skafte 2020-12-05 15:49:45 +01:00 committed by GitHub
parent 72349706c1
commit 1b40a4053d
3 changed files with 32 additions and 0 deletions
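As background for the diffs below: torch.distributed.all_gather requires its input tensor to be contiguous in memory, and a strided view such as a column slice is not. A minimal standalone sketch of the failure mode this commit guards against (no process group is needed to observe the contiguity itself):

import torch

# A column slice shares storage with its parent tensor, so its elements
# are strided in memory rather than dense.
x = torch.randn(10, 5)[:, 0]
print(x.is_contiguous())  # False

# .contiguous() returns a dense copy (and returns the tensor unchanged
# when it is already contiguous), which is the conversion the patched
# gather_all_tensors now performs before calling all_gather.
y = x.contiguous()
print(y.is_contiguous())  # True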


@@ -120,6 +120,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138))
 
+- Auto convert tensors to contiguous format when `gather_all` ([#4907](https://github.com/PyTorchLightning/pytorch-lightning/pull/4907))
+
+
 ## [1.0.8] - 2020-11-24
 
 ### Added


@@ -89,6 +89,9 @@ def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None)
     if group is None:
         group = torch.distributed.group.WORLD
 
+    # convert tensors to contiguous format
+    result = result.contiguous()
+
     world_size = torch.distributed.get_world_size(group)
 
     gathered_result = [torch.zeros_like(result) for _ in range(world_size)]

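Pieced together from the context lines above, the patched helper roughly reads as below; the trailing all_gather call and return are not shown in the hunk and are assumed here:

import torch

def gather_all_tensors(result, group=None):
    # Simplified sketch of the patched helper; type hints and docstring elided.
    if group is None:
        group = torch.distributed.group.WORLD

    # convert tensors to contiguous format (the fix in this commit)
    result = result.contiguous()

    world_size = torch.distributed.get_world_size(group)
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]

    # Assumed continuation: gather one copy of `result` from every rank.
    torch.distributed.all_gather(gathered_result, result, group=group)
    return gathered_result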

@@ -3,6 +3,7 @@ import sys
 
 import pytest
 import torch
+from pytorch_lightning.metrics import Metric
 
 from tests.metrics.test_metric import Dummy
 from tests.metrics.utils import setup_ddp
@@ -43,3 +44,28 @@ def _test_ddp_sum_cat(rank, worldsize):
 @pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat])
 def test_ddp(process):
     torch.multiprocessing.spawn(process, args=(2,), nprocs=2)
+
+
+def _test_non_contiguous_tensors(rank, worldsize):
+    setup_ddp(rank, worldsize)
+
+    class DummyMetric(Metric):
+        def __init__(self):
+            super().__init__()
+            self.add_state("x", default=[], dist_reduce_fx=None)
+
+        def update(self, x):
+            self.x.append(x)
+
+        def compute(self):
+            x = torch.cat(self.x, dim=0)
+            return x.sum()
+
+    metric = DummyMetric()
+    metric.update(torch.randn(10, 5)[:, 0])
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_non_contiguous_tensors():
+    """ Test that gather_all operation works for non contiguous tensors """
+    torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2,), nprocs=2)
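The test's DummyMetric stores a non-contiguous column view (torch.randn(10, 5)[:, 0] strides across rows rather than packing elements densely), so any cross-process sync of that state must gather a strided tensor. For illustration, a hedged standalone sketch of the same path, mirroring what setup_ddp does in the test suite (the gloo backend and master port are assumptions):

import os
import torch
import torch.distributed as dist


def _worker(rank, world_size):
    # Mirrors the test suite's setup_ddp helper (backend and port assumed).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "8088"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # The same non-contiguous column view the test feeds into the metric.
    x = torch.randn(10, 5)[:, 0]
    assert not x.is_contiguous()

    # Pre-#4907, passing `x` directly to all_gather failed on the
    # non-contiguous input; converting first is exactly what
    # gather_all_tensors now does.
    x = x.contiguous()
    gathered = [torch.zeros_like(x) for _ in range(world_size)]
    dist.all_gather(gathered, x)


if __name__ == "__main__":
    torch.multiprocessing.spawn(_worker, args=(2,), nprocs=2)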