Let TorchCollective work on the `torch.distributed` WORLD process group by default (#16995)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
parent c886317c0c
commit bb861cba7e
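In practice, this change means a TorchCollective can be used directly on the default WORLD process group when `torch.distributed` has already been initialized, without calling `create_group()` first. Below is a minimal sketch of that behavior, not taken from the diff: it assumes a script launched with `torchrun` and that `TorchCollective` is importable from `lightning_fabric.plugins.collectives`; the all-reduce arithmetic mirrors the new `_test_default_process_group` test further down.

import torch
import torch.distributed as dist

from lightning_fabric.plugins.collectives import TorchCollective  # assumed import path


def main() -> None:
    # e.g. `torchrun --nproc_per_node=3 this_script.py`
    dist.init_process_group("gloo")

    collective = TorchCollective()
    # No create_group() needed anymore: `group` falls back to dist.GroupMember.WORLD.
    assert collective.group == dist.group.WORLD

    # Each rank contributes `world_size`, so the all-reduced sum is world_size**2.
    world_size = dist.get_world_size()
    result = collective.all_reduce(torch.tensor(world_size))
    assert result == world_size**2

    dist.destroy_process_group()


if __name__ == "__main__":
    main()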
@@ -30,6 +30,12 @@ class TorchCollective(Collective):
             raise RuntimeError("Torch distributed is not available.")
         super().__init__()
 
+    @property
+    def group(self) -> CollectibleGroup:
+        if self._group is None:
+            self._group = dist.GroupMember.WORLD
+        return super().group
+
     @property
     def rank(self) -> int:
         # local rank
@@ -138,17 +144,20 @@ class TorchCollective(Collective):
         return self
 
     def teardown(self) -> Self:
-        non_group_member = self.group == dist.GroupMember.NON_GROUP_MEMBER
+        group_member = self.group != dist.GroupMember.NON_GROUP_MEMBER
         super().teardown()  # will destroy its own group
         # try to destroy the default group. this should only be done by a group member to avoid race conditions,
         # and only if the class is managing it
-        if not non_group_member and TorchCollective.manages_default_group:
-            default_group = dist.GroupMember.WORLD
-            if default_group is not None:  # not destroyed already
-                group_map = dist.distributed_c10d._pg_map
-                if len(group_map) == 1 and default_group in group_map:  # only the default group is left
-                    self.destroy_group(default_group)
-                    TorchCollective.manages_default_group = False
+        if (
+            group_member
+            and TorchCollective.manages_default_group
+            and (default_group := dist.GroupMember.WORLD) is not None  # not destroyed already
+            and len(dist.distributed_c10d._pg_map) == 1  # only the default group is left
+        ):
+            self.destroy_group(default_group)
+            TorchCollective.manages_default_group = False
+        elif TorchCollective.manages_default_group and dist.GroupMember.WORLD is None:
+            TorchCollective.manages_default_group = False
         return self
 
     @classmethod
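The rewritten `teardown` folds the old nested checks into one guarded `if`: the walrus operator both fetches and tests `dist.GroupMember.WORLD`, and short-circuit evaluation keeps that lookup (and the `_pg_map` length check) from running unless the process is a group member and the class actually manages the default group. The new `elif` also clears the `manages_default_group` flag when the WORLD group was already destroyed elsewhere. The following standalone helper is only an illustration of that decision logic, not code from the repository:

import torch.distributed as dist


def should_destroy_default_group(is_group_member: bool, manages_default_group: bool) -> bool:
    # Conditions are evaluated left to right and stop at the first False,
    # mirroring the removed nested `if` blocks in the old teardown.
    return (
        is_group_member
        and manages_default_group
        and dist.GroupMember.WORLD is not None  # not destroyed already
        and len(dist.distributed_c10d._pg_map) == 1  # only the default group is left
    )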
@@ -171,7 +180,8 @@ class TorchCollective(Collective):
     def destroy_group(cls, group: CollectibleGroup) -> None:
         # can be called by all processes in the default group, group will be `object()` if they are not part of the
         # current group
-        dist.destroy_process_group(group)  # type: ignore[arg-type]
+        if group in dist.distributed_c10d._pg_map:
+            dist.destroy_process_group(group)  # type: ignore[arg-type]
 
     @classmethod
     def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
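The added membership check makes `destroy_group` tolerant of a group that is no longer registered, or never was, such as the `object()` placeholder handed to non-members: `dist.destroy_process_group` would otherwise error on a group it does not know about, so the guard keeps repeated or late cleanup from raising. The remaining hunks update the TorchCollective tests for these changes.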
@@ -146,7 +146,7 @@ def test_convert_ops():
 def test_repeated_create_and_destroy():
     collective = TorchCollective()
     with mock.patch("torch.distributed.init_process_group"):
-        collective.setup(main_address="foo", main_port=123)
+        collective.setup(main_address="foo", main_port="123")
 
     assert not os.environ
 
@@ -157,7 +157,9 @@ def test_repeated_create_and_destroy():
     with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
         collective.create_group()
 
-    with mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+    with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
+        "torch.distributed.destroy_process_group"
+    ) as destroy_mock:
         collective.teardown()
     # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
     # group
@@ -269,3 +271,38 @@ def _test_two_groups(strategy, left_collective, right_collective):
 @pytest.mark.skip(reason="TODO(carmocca): causing hangs in CI")
 def test_two_groups():
     collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)
+
+
+def _test_default_process_group(strategy, *collectives):
+    for collective in collectives:
+        assert collective.group == torch.distributed.group.WORLD
+    world_size = strategy.world_size
+    for c in collectives:
+        tensor = torch.tensor(world_size)
+        r = c.all_reduce(tensor)
+        assert world_size**2 == r
+
+
+@skip_distributed_unavailable
+@RunIf(skip_windows=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)  # sets CUDA_MODULE_LOADING in torch==1.13
+def test_default_process_group():
+    collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
+
+
+@skip_distributed_unavailable
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_collective_manages_default_group():
+    collective = TorchCollective()
+    with mock.patch("torch.distributed.init_process_group"):
+        collective.setup(main_address="foo", main_port="123")
+
+    assert TorchCollective.manages_default_group
+
+    with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
+        "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
+    ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+        collective.teardown()
+        destroy_mock.assert_called_once_with(mock_group)
+
+    assert not TorchCollective.manages_default_group
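One detail in the new tests worth calling out: `mock.patch.dict(os.environ, os.environ.copy(), clear=True)` snapshots the current environment and restores it when the test exits, so variables set as a side effect of launching processes (such as CUDA_MODULE_LOADING, which torch==1.13 sets) do not leak into other tests. A standalone illustration of that idiom, unrelated to the repository code:

import os
from unittest import mock

with mock.patch.dict(os.environ, os.environ.copy(), clear=True):
    os.environ["SET_DURING_TEST"] = "1"  # e.g. CUDA_MODULE_LOADING under torch==1.13
assert "SET_DURING_TEST" not in os.environ  # the original environment is restored on exit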