Let TorchCollective work on the `torch.distributed` WORLD process group by default (#16995)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
commit bb861cba7e
parent c886317c0c
Author: belerico
Date:   2023-03-21 00:30:27 +01:00 (committed by GitHub)
2 changed files with 58 additions and 11 deletions


@@ -30,6 +30,12 @@ class TorchCollective(Collective):
             raise RuntimeError("Torch distributed is not available.")
         super().__init__()
 
+    @property
+    def group(self) -> CollectibleGroup:
+        if self._group is None:
+            self._group = dist.GroupMember.WORLD
+        return super().group
+
     @property
     def rank(self) -> int:
         # local rank
@@ -138,17 +144,20 @@ class TorchCollective(Collective):
         return self
 
     def teardown(self) -> Self:
-        non_group_member = self.group == dist.GroupMember.NON_GROUP_MEMBER
+        group_member = self.group != dist.GroupMember.NON_GROUP_MEMBER
         super().teardown()  # will destroy its own group
         # try to destroy the default group. this should only be done by a group member to avoid race conditions,
         # and only if the class is managing it
-        if not non_group_member and TorchCollective.manages_default_group:
-            default_group = dist.GroupMember.WORLD
-            if default_group is not None:  # not destroyed already
-                group_map = dist.distributed_c10d._pg_map
-                if len(group_map) == 1 and default_group in group_map:  # only the default group is left
-                    self.destroy_group(default_group)
-                    TorchCollective.manages_default_group = False
+        if (
+            group_member
+            and TorchCollective.manages_default_group
+            and (default_group := dist.GroupMember.WORLD) is not None  # not destroyed already
+            and len(dist.distributed_c10d._pg_map) == 1  # only the default group is left
+        ):
+            self.destroy_group(default_group)
+            TorchCollective.manages_default_group = False
+        elif TorchCollective.manages_default_group and dist.GroupMember.WORLD is None:
+            TorchCollective.manages_default_group = False
         return self
 
     @classmethod
@@ -171,7 +180,8 @@ class TorchCollective(Collective):
     def destroy_group(cls, group: CollectibleGroup) -> None:
         # can be called by all processes in the default group, group will be `object()` if they are not part of the
         # current group
-        dist.destroy_process_group(group)  # type: ignore[arg-type]
+        if group in dist.distributed_c10d._pg_map:
+            dist.destroy_process_group(group)  # type: ignore[arg-type]
 
     @classmethod
     def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
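
For context, a minimal sketch of the behavior the new `group` property enables: a `TorchCollective` that never called `create_group()` now falls back to the default WORLD process group. This is illustrative only, not part of the diff; it assumes the default process group was already initialized elsewhere and that `TorchCollective` is importable from `lightning.fabric.plugins.collectives`.

```python
import torch
import torch.distributed as dist

from lightning.fabric.plugins.collectives import TorchCollective  # assumed import path

# Assumes the default process group was already initialized elsewhere, e.g.
# dist.init_process_group("gloo", rank=0, world_size=1, init_method="tcp://127.0.0.1:29500")
collective = TorchCollective()

# With the new `group` property, no explicit `create_group()` call is needed:
# the collective falls back to the default WORLD process group.
assert collective.group == dist.GroupMember.WORLD

# Collective ops then run against the default group.
result = collective.all_reduce(torch.tensor(1))
```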


@@ -146,7 +146,7 @@ def test_convert_ops():
 def test_repeated_create_and_destroy():
     collective = TorchCollective()
     with mock.patch("torch.distributed.init_process_group"):
-        collective.setup(main_address="foo", main_port=123)
+        collective.setup(main_address="foo", main_port="123")
 
     assert not os.environ
@@ -157,7 +157,9 @@ def test_repeated_create_and_destroy():
     with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
         collective.create_group()
 
-    with mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+    with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
+        "torch.distributed.destroy_process_group"
+    ) as destroy_mock:
         collective.teardown()
     # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
     # group
@@ -269,3 +271,38 @@ def _test_two_groups(strategy, left_collective, right_collective):
 @pytest.mark.skip(reason="TODO(carmocca): causing hangs in CI")
 def test_two_groups():
     collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)
+
+
+def _test_default_process_group(strategy, *collectives):
+    for collective in collectives:
+        assert collective.group == torch.distributed.group.WORLD
+    world_size = strategy.world_size
+    for c in collectives:
+        tensor = torch.tensor(world_size)
+        r = c.all_reduce(tensor)
+        assert world_size**2 == r
+
+
+@skip_distributed_unavailable
+@RunIf(skip_windows=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)  # sets CUDA_MODULE_LOADING in torch==1.13
+def test_default_process_group():
+    collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
+
+
+@skip_distributed_unavailable
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_collective_manages_default_group():
+    collective = TorchCollective()
+    with mock.patch("torch.distributed.init_process_group"):
+        collective.setup(main_address="foo", main_port="123")
+    assert TorchCollective.manages_default_group
+
+    with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
+        "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
+    ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+        collective.teardown()
+    destroy_mock.assert_called_once_with(mock_group)
+
+    assert not TorchCollective.manages_default_group
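
Similarly, a small sketch of what the new `_pg_map` membership check in `destroy_group` buys: destroying a group that was already torn down becomes a no-op instead of erroring inside `torch.distributed`. This is illustrative only and assumes a single-process gloo setup and the same import path as above.

```python
import torch.distributed as dist

from lightning.fabric.plugins.collectives import TorchCollective  # assumed import path

# Initialize the default WORLD group for a single process.
dist.init_process_group("gloo", rank=0, world_size=1, init_method="tcp://127.0.0.1:29501")
world = dist.GroupMember.WORLD

TorchCollective.destroy_group(world)  # destroys the default group and drops it from `_pg_map`
TorchCollective.destroy_group(world)  # second call is skipped: `world` is no longer in `_pg_map`
```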