import contextlib
import datetime
import os
from functools import partial
from unittest import mock

import pytest
import torch
from tests_fabric.helpers.runif import RunIf

from lightning_fabric.accelerators import CPUAccelerator, CUDAAccelerator
from lightning_fabric.plugins.collectives import TorchCollective
from lightning_fabric.plugins.environments import LightningEnvironment
from lightning_fabric.strategies.ddp import DDPStrategy
from lightning_fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13

if TorchCollective.is_available():
    from torch.distributed import ReduceOp
else:
    ReduceOp = mock.Mock()

skip_distributed_unavailable = pytest.mark.skipif(
    not TorchCollective.is_available(), reason="torch.distributed unavailable"
)

PASSED_TENSOR = mock.Mock()
PASSED_OBJECT = mock.Mock()


@contextlib.contextmanager
def check_destroy_group():
    with mock.patch(
        "lightning_fabric.plugins.collectives.torch_collective.TorchCollective.new_group",
        wraps=TorchCollective.new_group,
    ) as mock_new, mock.patch(
        "lightning_fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group",
        wraps=TorchCollective.destroy_group,
    ) as mock_destroy:
        yield
    # 0 to account for tests that mock distributed
    # -1 to account for destroying the default process group
    expected = 0 if mock_new.call_count == 0 else mock_destroy.call_count - 1
    assert mock_new.call_count == expected, f"new_group={mock_new.call_count}, destroy_group={mock_destroy.call_count}"
    if TorchCollective.is_available():
        assert not torch.distributed.distributed_c10d._pg_map
        assert not TorchCollective.is_initialized()


@pytest.mark.parametrize(
    ["fn_name", "kwargs", "return_key"],
    [
        ("send", {"tensor": PASSED_TENSOR, "dst": 0, "tag": 0}, None),
        ("recv", {"tensor": PASSED_TENSOR, "src": 0, "tag": 0}, "tensor"),
        ("broadcast", {"tensor": PASSED_TENSOR, "src": 0}, "tensor"),
        ("all_reduce", {"tensor": PASSED_TENSOR, "op": ReduceOp.SUM}, "tensor"),
        ("reduce", {"tensor": PASSED_TENSOR, "dst": 0, "op": ReduceOp.SUM}, "tensor"),
        ("all_gather", {"tensor_list": [PASSED_TENSOR], "tensor": PASSED_TENSOR}, "tensor_list"),
        ("gather", {"tensor": PASSED_TENSOR, "gather_list": [PASSED_TENSOR], "dst": 0}, "gather_list"),
        ("scatter", {"tensor": PASSED_TENSOR, "scatter_list": [PASSED_TENSOR], "src": 0}, "tensor"),
        ("reduce_scatter", {"output": PASSED_TENSOR, "input_list": [PASSED_TENSOR], "op": ReduceOp.SUM}, "output"),
        (
            "all_to_all",
            {"output_tensor_list": [PASSED_TENSOR], "input_tensor_list": [PASSED_TENSOR]},
            "output_tensor_list",
        ),
        ("barrier", {"device_ids": [0]}, None),
        ("all_gather_object", {"object_list": [PASSED_OBJECT], "obj": PASSED_OBJECT}, "object_list"),
        (
            "broadcast_object_list",
            {"object_list": [PASSED_OBJECT], "src": 0, "device": torch.device("cpu")},
            "object_list",
        ),
        (
            "gather_object",
            {"obj": PASSED_OBJECT, "object_gather_list": [PASSED_OBJECT], "dst": 0},
            "object_gather_list",
        ),
        (
            "scatter_object_list",
            {"scatter_object_output_list": [PASSED_OBJECT], "scatter_object_input_list": [PASSED_OBJECT], "src": 0},
            "scatter_object_output_list",
        ),
        ("monitored_barrier", {"timeout": datetime.timedelta(seconds=1), "wait_all_ranks": False}, None),
    ],
)
@skip_distributed_unavailable
def test_collective_calls_with_created_group(fn_name, kwargs, return_key):
    collective = TorchCollective()
    with mock.patch("torch.distributed.init_process_group"):
        collective.setup()
    with mock.patch("torch.distributed.new_group"):
        collective.create_group()
    fn = getattr(collective, fn_name)
    with mock.patch(f"torch.distributed.{fn_name}", autospec=True) as mock_call:
        result = fn(**kwargs)
    mock_call.assert_called_once_with(**kwargs, group=collective.group)
    if return_key is not None:
        assert result == kwargs[return_key]

    with mock.patch("torch.distributed.destroy_process_group"):
        collective.teardown()


@skip_distributed_unavailable
def test_convert_ops():
    # Test regular names
    assert TorchCollective._convert_to_native_op("band") == ReduceOp.BAND
    assert TorchCollective._convert_to_native_op("bor") == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op("bxor") == ReduceOp.BXOR
    assert TorchCollective._convert_to_native_op("max") == ReduceOp.MAX
    assert TorchCollective._convert_to_native_op("min") == ReduceOp.MIN
    assert TorchCollective._convert_to_native_op("product") == ReduceOp.PRODUCT
    assert TorchCollective._convert_to_native_op("sum") == ReduceOp.SUM
    # Test we are passing through native ops without change
    assert TorchCollective._convert_to_native_op(ReduceOp.BAND) == ReduceOp.BAND
    assert TorchCollective._convert_to_native_op(ReduceOp.BOR) == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op(ReduceOp.BXOR) == ReduceOp.BXOR
    assert TorchCollective._convert_to_native_op(ReduceOp.MAX) == ReduceOp.MAX
    assert TorchCollective._convert_to_native_op(ReduceOp.MIN) == ReduceOp.MIN
    assert TorchCollective._convert_to_native_op(ReduceOp.PRODUCT) == ReduceOp.PRODUCT
    assert TorchCollective._convert_to_native_op(ReduceOp.SUM) == ReduceOp.SUM
    # Test we are handling different casing properly
    assert TorchCollective._convert_to_native_op("BOR") == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op("BoR") == ReduceOp.BOR
    # AVG is very recent!
    if _TORCH_GREATER_EQUAL_1_11:
        assert TorchCollective._convert_to_native_op("avg") == ReduceOp.AVG

    # Test invalid type
    with pytest.raises(ValueError, match="Unsupported op 1 of type int"):
        TorchCollective._convert_to_native_op(1)

    # Test invalid string
    with pytest.raises(ValueError, match="op 'INVALID' is not a member of `Red"):
        TorchCollective._convert_to_native_op("invalid")

    # Test RedOpType
    if _TORCH_GREATER_EQUAL_1_13:
        assert TorchCollective._convert_to_native_op(ReduceOp.RedOpType.AVG) == ReduceOp.RedOpType.AVG
        op = torch.distributed._make_nccl_premul_sum(2.0)  # this returns a ReduceOp
        assert TorchCollective._convert_to_native_op(op) == ReduceOp.PREMUL_SUM
        assert TorchCollective._convert_to_native_op("premul_sum") == ReduceOp.PREMUL_SUM


@skip_distributed_unavailable
@mock.patch.dict(os.environ, {}, clear=True)
def test_repeated_create_and_destroy():
    collective = TorchCollective()
    with mock.patch("torch.distributed.init_process_group"):
        collective.setup(main_address="foo", main_port=123)

    assert not os.environ

    with mock.patch("torch.distributed.new_group") as new_mock:
        collective.create_group()
    new_mock.assert_called_once()

    with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
        collective.create_group()

    with mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
        collective.teardown()
    # this would be called twice if `init_process_group` wasn't patched:
    # once for the group and once for the default group
    destroy_mock.assert_called_once()

    assert not os.environ

    with pytest.raises(RuntimeError, match="TorchCollective` does not own a group"):
        collective.teardown()
    destroy_mock.assert_called_once_with(new_mock.return_value)
    assert collective._group is None

    with mock.patch("torch.distributed.new_group"), mock.patch("torch.distributed.destroy_process_group"):
        # check we can create_group again. also chaining
        collective.create_group().teardown()


def collective_launch(fn, parallel_devices, num_groups=1):
    device_to_accelerator = {"cuda": CUDAAccelerator, "cpu": CPUAccelerator}
    accelerator_cls = device_to_accelerator[parallel_devices[0].type]
    strategy = DDPStrategy(
        accelerator=accelerator_cls(),
        parallel_devices=parallel_devices,
        cluster_environment=LightningEnvironment(),
        start_method="spawn",
    )
    launcher = _MultiProcessingLauncher(strategy=strategy)
    collectives = [TorchCollective() for _ in range(num_groups)]
    wrapped = partial(wrap_launch_function, fn, strategy, collectives)
    return launcher.launch(wrapped, strategy, *collectives)


def wrap_launch_function(fn, strategy, collectives, *args, **kwargs):
    strategy._set_world_ranks()
    collectives[0].setup(  # only one needs to setup
        world_size=strategy.num_processes,
        main_address="localhost",
        backend=strategy._get_process_group_backend(),
        rank=strategy.global_rank,
    )
    with check_destroy_group():  # manually use the fixture for the assertions
        fn(*args, **kwargs)
        # not necessary since they will be destroyed on process destruction, only added to fulfill the assertions
        for c in collectives:
            c.teardown()


def _test_distributed_collectives_fn(strategy, collective):
    collective.create_group()

    # all_gather
    tensor_list = [torch.zeros(2, dtype=torch.long) for _ in range(strategy.num_processes)]
    this = torch.arange(2, dtype=torch.long) + 2 * strategy.global_rank
    out = collective.all_gather(tensor_list, this)
    expected = torch.arange(2 * strategy.num_processes).split(2)
    torch.testing.assert_close(tuple(out), expected)

    # reduce
    this = torch.tensor(strategy.global_rank + 1)
    out = collective.reduce(this, dst=0, op="max")
    expected = torch.tensor(strategy.num_processes) if strategy.global_rank == 0 else this
    torch.testing.assert_close(out, expected)

    # all_reduce
    this = torch.tensor(strategy.global_rank + 1)
    out = collective.all_reduce(this, op=ReduceOp.MIN)
    expected = torch.tensor(1)
    torch.testing.assert_close(out, expected)


@skip_distributed_unavailable
@pytest.mark.parametrize("n", (1, 2))
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)  # sets CUDA_MODULE_LOADING in torch==1.13
def test_collectives_distributed(n):
    collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n)


def _test_distributed_collectives_cuda_fn(strategy, collective):
    collective.create_group()

    this = torch.tensor(1.5, device=strategy.root_device)
    premul_sum = torch.distributed._make_nccl_premul_sum(2.0)
    out = collective.all_reduce(this, op=premul_sum)
    assert out == 3


@skip_distributed_unavailable
@RunIf(min_cuda_gpus=1, min_torch="1.13")
def test_collectives_distributed_cuda():
    collective_launch(_test_distributed_collectives_cuda_fn, [torch.device("cuda")])


def _test_two_groups(strategy, left_collective, right_collective):
    left_collective.create_group(ranks=[0, 1])
    right_collective.create_group(ranks=[1, 2])

    tensor = torch.tensor(strategy.global_rank)
    if strategy.global_rank in (0, 1):
        tensor = left_collective.all_reduce(tensor)
        assert tensor == 1
    right_collective.barrier()  # avoids deadlock for global rank 1
    if strategy.global_rank in (1, 2):
        tensor = right_collective.all_reduce(tensor)
        assert tensor == 3


@skip_distributed_unavailable
@pytest.mark.skip(reason="TODO(carmocca): causing hangs in CI")
def test_two_groups():
    collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)