Auto convert to contiguous format for all_gather (#4907)
* convert memory format
* changelog
* formatting
* suggestions
* retrigger tests

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: chaton <thomas@grid.ai>
parent 72349706c1
commit 1b40a4053d
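Background for the change: `torch.distributed.all_gather` operates on the raw memory of its input, so it expects a contiguous tensor, while common view operations such as column slicing produce strided, non-contiguous views. A minimal single-process sketch of the mismatch (plain PyTorch, nothing Lightning-specific):

import torch

# Slicing a column returns a strided view that shares storage
# with the parent tensor; it is not laid out contiguously.
x = torch.randn(10, 5)[:, 0]
print(x.is_contiguous())  # False

# .contiguous() copies the view into dense, row-major memory,
# which is the layout torch.distributed collectives expect.
y = x.contiguous()
print(y.is_contiguous())  # True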
@@ -120,6 +120,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138))
 
+- Auto convert tensors to contiguous format when `gather_all` ([#4907](https://github.com/PyTorchLightning/pytorch-lightning/pull/4907))
+
+
 ## [1.0.8] - 2020-11-24
 
 ### Added
@@ -89,6 +89,9 @@ def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
     if group is None:
         group = torch.distributed.group.WORLD
 
+    # convert tensors to contiguous format
+    result = result.contiguous()
+
     world_size = torch.distributed.get_world_size(group)
 
     gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
@@ -3,6 +3,7 @@ import sys
 import pytest
 import torch
 
+from pytorch_lightning.metrics import Metric
 from tests.metrics.test_metric import Dummy
 from tests.metrics.utils import setup_ddp
 
@@ -43,3 +44,28 @@ def _test_ddp_sum_cat(rank, worldsize):
 @pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat])
 def test_ddp(process):
     torch.multiprocessing.spawn(process, args=(2,), nprocs=2)
+
+
+def _test_non_contiguous_tensors(rank, worldsize):
+    setup_ddp(rank, worldsize)
+
+    class DummyMetric(Metric):
+        def __init__(self):
+            super().__init__()
+            self.add_state("x", default=[], dist_reduce_fx=None)
+
+        def update(self, x):
+            self.x.append(x)
+
+        def compute(self):
+            x = torch.cat(self.x, dim=0)
+            return x.sum()
+
+    metric = DummyMetric()
+    metric.update(torch.randn(10, 5)[:, 0])
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_non_contiguous_tensors():
+    """ Test that gather_all operation works for non contiguous tensors """
+    torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2,), nprocs=2)
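A subtlety worth noting in these tests: `torch.multiprocessing.spawn` prepends the process index to `args`, so each worker above is invoked as `process(rank, 2)`, with `worldsize` supplied by `args=(2,)`. A standalone sketch of the same pattern:

import torch.multiprocessing as mp


def _worker(rank, world_size):
    # rank is injected by spawn; world_size comes from args=(2,)
    print(f"rank {rank} of {world_size}")


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)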