diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py
index 10323379ed..05c2a5b4b2 100644
--- a/src/lightning/fabric/plugins/collectives/torch_collective.py
+++ b/src/lightning/fabric/plugins/collectives/torch_collective.py
@@ -30,6 +30,12 @@ class TorchCollective(Collective):
             raise RuntimeError("Torch distributed is not available.")
         super().__init__()
 
+    @property
+    def group(self) -> CollectibleGroup:
+        if self._group is None:
+            self._group = dist.GroupMember.WORLD
+        return super().group
+
     @property
     def rank(self) -> int:
         # local rank
@@ -138,17 +144,20 @@ class TorchCollective(Collective):
         return self
 
     def teardown(self) -> Self:
-        non_group_member = self.group == dist.GroupMember.NON_GROUP_MEMBER
+        group_member = self.group != dist.GroupMember.NON_GROUP_MEMBER
         super().teardown()  # will destroy its own group
         # try to destroy the default group. this should only be done by a group member to avoid race conditions,
         # and only if the class is managing it
-        if not non_group_member and TorchCollective.manages_default_group:
-            default_group = dist.GroupMember.WORLD
-            if default_group is not None:  # not destroyed already
-                group_map = dist.distributed_c10d._pg_map
-                if len(group_map) == 1 and default_group in group_map:  # only the default group is left
-                    self.destroy_group(default_group)
-                    TorchCollective.manages_default_group = False
+        if (
+            group_member
+            and TorchCollective.manages_default_group
+            and (default_group := dist.GroupMember.WORLD) is not None  # not destroyed already
+            and len(dist.distributed_c10d._pg_map) == 1  # only the default group is left
+        ):
+            self.destroy_group(default_group)
+            TorchCollective.manages_default_group = False
+        elif TorchCollective.manages_default_group and dist.GroupMember.WORLD is None:
+            TorchCollective.manages_default_group = False
         return self
 
     @classmethod
@@ -171,7 +180,8 @@ class TorchCollective(Collective):
     def destroy_group(cls, group: CollectibleGroup) -> None:
         # can be called by all processes in the default group, group will be `object()` if they are not part of the
         # current group
-        dist.destroy_process_group(group)  # type: ignore[arg-type]
+        if group in dist.distributed_c10d._pg_map:
+            dist.destroy_process_group(group)  # type: ignore[arg-type]
 
     @classmethod
     def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py
index ca9f5858d2..cedf6fb3df 100644
--- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py
+++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py
@@ -146,7 +146,7 @@ def test_convert_ops():
 def test_repeated_create_and_destroy():
     collective = TorchCollective()
     with mock.patch("torch.distributed.init_process_group"):
-        collective.setup(main_address="foo", main_port=123)
+        collective.setup(main_address="foo", main_port="123")
 
     assert not os.environ
 
@@ -157,7 +157,9 @@ def test_repeated_create_and_destroy():
     with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
         collective.create_group()
 
-    with mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+    with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
+        "torch.distributed.destroy_process_group"
+    ) as destroy_mock:
         collective.teardown()
     # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
     # group
@@ -269,3 +271,38 @@ def _test_two_groups(strategy, left_collective, right_collective):
 @pytest.mark.skip(reason="TODO(carmocca): causing hangs in CI")
 def test_two_groups():
     collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)
+
+
+def _test_default_process_group(strategy, *collectives):
+    for collective in collectives:
+        assert collective.group == torch.distributed.group.WORLD
+    world_size = strategy.world_size
+    for c in collectives:
+        tensor = torch.tensor(world_size)
+        r = c.all_reduce(tensor)
+        assert world_size**2 == r
+
+
+@skip_distributed_unavailable
+@RunIf(skip_windows=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)  # sets CUDA_MODULE_LOADING in torch==1.13
+def test_default_process_group():
+    collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
+
+
+@skip_distributed_unavailable
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_collective_manages_default_group():
+    collective = TorchCollective()
+    with mock.patch("torch.distributed.init_process_group"):
+        collective.setup(main_address="foo", main_port="123")
+
+    assert TorchCollective.manages_default_group
+
+    with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
+        "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
+    ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+        collective.teardown()
+    destroy_mock.assert_called_once_with(mock_group)
+
+    assert not TorchCollective.manages_default_group