import contextlib
import datetime
import os
from functools import partial
from unittest import mock

import pytest
import torch
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator
from lightning.fabric.plugins.collectives import TorchCollective
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies.ddp import DDPStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13

from tests_fabric.helpers.runif import RunIf

if TorchCollective.is_available():
    from torch.distributed import ReduceOp
else:
    ReduceOp = mock.Mock()

skip_distributed_unavailable = pytest.mark.skipif(
    not TorchCollective.is_available(), reason="torch.distributed unavailable"
)

PASSED_TENSOR = mock.Mock()
PASSED_OBJECT = mock.Mock()


@contextlib.contextmanager
def check_destroy_group():
    with mock.patch(
        "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group",
        wraps=TorchCollective.new_group,
    ) as mock_new, mock.patch(
        "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group",
        wraps=TorchCollective.destroy_group,
    ) as mock_destroy:
        yield
    # 0 to account for tests that mock distributed
    # -1 to account for destroying the default process group
    expected = 0 if mock_new.call_count == 0 else mock_destroy.call_count - 1
    assert mock_new.call_count == expected, f"new_group={mock_new.call_count}, destroy_group={mock_destroy.call_count}"
    if TorchCollective.is_available():
        assert not torch.distributed.distributed_c10d._pg_map
        assert not TorchCollective.is_initialized()


@pytest.mark.parametrize(
    ("fn_name", "kwargs", "return_key"),
    [
        ("send", {"tensor": PASSED_TENSOR, "dst": 0, "tag": 0}, None),
        ("recv", {"tensor": PASSED_TENSOR, "src": 0, "tag": 0}, "tensor"),
        ("broadcast", {"tensor": PASSED_TENSOR, "src": 0}, "tensor"),
        ("all_reduce", {"tensor": PASSED_TENSOR, "op": ReduceOp.SUM}, "tensor"),
        ("reduce", {"tensor": PASSED_TENSOR, "dst": 0, "op": ReduceOp.SUM}, "tensor"),
        ("all_gather", {"tensor_list": [PASSED_TENSOR], "tensor": PASSED_TENSOR}, "tensor_list"),
        ("gather", {"tensor": PASSED_TENSOR, "gather_list": [PASSED_TENSOR], "dst": 0}, "gather_list"),
        ("scatter", {"tensor": PASSED_TENSOR, "scatter_list": [PASSED_TENSOR], "src": 0}, "tensor"),
        ("reduce_scatter", {"output": PASSED_TENSOR, "input_list": [PASSED_TENSOR], "op": ReduceOp.SUM}, "output"),
        (
            "all_to_all",
            {"output_tensor_list": [PASSED_TENSOR], "input_tensor_list": [PASSED_TENSOR]},
            "output_tensor_list",
        ),
        ("barrier", {"device_ids": [0]}, None),
        ("all_gather_object", {"object_list": [PASSED_OBJECT], "obj": PASSED_OBJECT}, "object_list"),
        (
            "broadcast_object_list",
            {"object_list": [PASSED_OBJECT], "src": 0, "device": torch.device("cpu")},
            "object_list",
        ),
        (
            "gather_object",
            {"obj": PASSED_OBJECT, "object_gather_list": [PASSED_OBJECT], "dst": 0},
            "object_gather_list",
        ),
        (
            "scatter_object_list",
            {"scatter_object_output_list": [PASSED_OBJECT], "scatter_object_input_list": [PASSED_OBJECT], "src": 0},
            "scatter_object_output_list",
        ),
        ("monitored_barrier", {"timeout": datetime.timedelta(seconds=1), "wait_all_ranks": False}, None),
    ],
)
@skip_distributed_unavailable
def test_collective_calls_with_created_group(fn_name, kwargs, return_key):
    collective = TorchCollective()
    with mock.patch("torch.distributed.init_process_group"):
        collective.setup()
    with mock.patch("torch.distributed.new_group"):
        collective.create_group()
    fn = getattr(collective, fn_name)
    with mock.patch(f"torch.distributed.{fn_name}", autospec=True) as mock_call:
        result = fn(**kwargs)
    mock_call.assert_called_once_with(**kwargs, group=collective.group)
    if return_key is not None:
        assert result == kwargs[return_key]

    with mock.patch("torch.distributed.destroy_process_group"):
        collective.teardown()


@skip_distributed_unavailable
def test_convert_ops():
    # Test regular names
    assert TorchCollective._convert_to_native_op("band") == ReduceOp.BAND
    assert TorchCollective._convert_to_native_op("bor") == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op("bxor") == ReduceOp.BXOR
    assert TorchCollective._convert_to_native_op("max") == ReduceOp.MAX
    assert TorchCollective._convert_to_native_op("min") == ReduceOp.MIN
    assert TorchCollective._convert_to_native_op("product") == ReduceOp.PRODUCT
    assert TorchCollective._convert_to_native_op("sum") == ReduceOp.SUM
    # Test we are passing through native ops without change
    assert TorchCollective._convert_to_native_op(ReduceOp.BAND) == ReduceOp.BAND
    assert TorchCollective._convert_to_native_op(ReduceOp.BOR) == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op(ReduceOp.BXOR) == ReduceOp.BXOR
    assert TorchCollective._convert_to_native_op(ReduceOp.MAX) == ReduceOp.MAX
    assert TorchCollective._convert_to_native_op(ReduceOp.MIN) == ReduceOp.MIN
    assert TorchCollective._convert_to_native_op(ReduceOp.PRODUCT) == ReduceOp.PRODUCT
    assert TorchCollective._convert_to_native_op(ReduceOp.SUM) == ReduceOp.SUM
    # Test we are handling different casing properly
    assert TorchCollective._convert_to_native_op("BOR") == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op("BoR") == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op("avg") == ReduceOp.AVG

    # Test invalid type
    with pytest.raises(ValueError, match="Unsupported op 1 of type int"):
        TorchCollective._convert_to_native_op(1)

    # Test invalid string
    with pytest.raises(ValueError, match="op 'INVALID' is not a member of `Red"):
        TorchCollective._convert_to_native_op("invalid")

    # Test RedOpType
    if _TORCH_GREATER_EQUAL_1_13:
        assert TorchCollective._convert_to_native_op(ReduceOp.RedOpType.AVG) == ReduceOp.RedOpType.AVG
        op = torch.distributed._make_nccl_premul_sum(2.0)  # this returns a ReduceOp
        assert TorchCollective._convert_to_native_op(op) == ReduceOp.PREMUL_SUM
        assert TorchCollective._convert_to_native_op("premul_sum") == ReduceOp.PREMUL_SUM


@skip_distributed_unavailable
@mock.patch.dict(os.environ, {}, clear=True)
def test_repeated_create_and_destroy():
    collective = TorchCollective()
    with mock.patch("torch.distributed.init_process_group"):
        collective.setup(main_address="foo", main_port="123")
    assert not os.environ

    with mock.patch("torch.distributed.new_group") as new_mock:
        collective.create_group()
    new_mock.assert_called_once()

    with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
        collective.create_group()

    with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
        "torch.distributed.destroy_process_group"
    ) as destroy_mock:
        collective.teardown()
    # this would be called twice if `init_process_group` wasn't patched: once for the group and once for the
    # default group
    destroy_mock.assert_called_once()

    assert not os.environ

    with pytest.raises(RuntimeError, match="TorchCollective` does not own a group"):
        collective.teardown()
    destroy_mock.assert_called_once_with(new_mock.return_value)
    assert collective._group is None

    with mock.patch("torch.distributed.new_group"), mock.patch("torch.distributed.destroy_process_group"):
        # check that we can `create_group` again and that the calls can be chained
        collective.create_group().teardown()


def collective_launch(fn, parallel_devices, num_groups=1):
    device_to_accelerator = {"cuda": CUDAAccelerator, "cpu": CPUAccelerator}
    accelerator_cls = device_to_accelerator[parallel_devices[0].type]
    strategy = DDPStrategy(
        accelerator=accelerator_cls(),
        parallel_devices=parallel_devices,
        cluster_environment=LightningEnvironment(),
        start_method="spawn",
    )
    launcher = _MultiProcessingLauncher(strategy=strategy)
    collectives = [TorchCollective() for _ in range(num_groups)]
    wrapped = partial(wrap_launch_function, fn, strategy, collectives)
    return launcher.launch(wrapped, strategy, *collectives)


def wrap_launch_function(fn, strategy, collectives, *args, **kwargs):
    strategy._set_world_ranks()
    collectives[0].setup(  # only one collective needs to set up the default process group
        world_size=strategy.num_processes,
        main_address="localhost",
        backend=strategy._get_process_group_backend(),
        rank=strategy.global_rank,
    )
    with check_destroy_group():  # manually use the fixture for the assertions
        fn(*args, **kwargs)
        # not necessary since the groups will be destroyed on process shutdown; only done to fulfill the assertions
        for c in collectives:
            c.teardown()


def _test_distributed_collectives_fn(strategy, collective):
    collective.create_group()

    # all_gather
    tensor_list = [torch.zeros(2, dtype=torch.long) for _ in range(strategy.num_processes)]
    this = torch.arange(2, dtype=torch.long) + 2 * strategy.global_rank
    out = collective.all_gather(tensor_list, this)
    expected = torch.arange(2 * strategy.num_processes).split(2)
    torch.testing.assert_close(tuple(out), expected)

    # reduce
    this = torch.tensor(strategy.global_rank + 1)
    out = collective.reduce(this, dst=0, op="max")
    expected = torch.tensor(strategy.num_processes) if strategy.global_rank == 0 else this
    torch.testing.assert_close(out, expected)

    # all_reduce
    this = torch.tensor(strategy.global_rank + 1)
    out = collective.all_reduce(this, op=ReduceOp.MIN)
    expected = torch.tensor(1)
    torch.testing.assert_close(out, expected)


@pytest.mark.skip(reason="test hangs too often")
@skip_distributed_unavailable
@pytest.mark.parametrize(
    "n", [1, pytest.param(2, marks=[RunIf(skip_windows=True), pytest.mark.xfail(raises=TimeoutError, strict=False)])]
)
def test_collectives_distributed(n):
    collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n)


def _test_distributed_collectives_cuda_fn(strategy, collective):
    collective.create_group()

    this = torch.tensor(1.5, device=strategy.root_device)
    premul_sum = torch.distributed._make_nccl_premul_sum(2.0)
    out = collective.all_reduce(this, op=premul_sum)
    assert out == 3


@skip_distributed_unavailable
@RunIf(min_cuda_gpus=1, min_torch="1.13")
def test_collectives_distributed_cuda():
    collective_launch(_test_distributed_collectives_cuda_fn, [torch.device("cuda")])


def _test_two_groups(strategy, left_collective, right_collective):
    left_collective.create_group(ranks=[0, 1])
    right_collective.create_group(ranks=[1, 2])

    tensor = torch.tensor(strategy.global_rank)
    if strategy.global_rank in (0, 1):
        tensor = left_collective.all_reduce(tensor)
        assert tensor == 1
    right_collective.barrier()  # avoids deadlock for global rank 1
    if strategy.global_rank in (1, 2):
        tensor = right_collective.all_reduce(tensor)
        assert tensor == 3


@skip_distributed_unavailable
@pytest.mark.flaky(reruns=5)
@RunIf(skip_windows=True)  # unhandled timeouts
@pytest.mark.xfail(raises=TimeoutError, strict=False)
def test_two_groups():
    collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)


def _test_default_process_group(strategy, *collectives):
    for collective in collectives:
        assert collective.group == torch.distributed.group.WORLD
    world_size = strategy.world_size
    for c in collectives:
        tensor = torch.tensor(world_size)
        r = c.all_reduce(tensor)
        assert world_size**2 == r


@skip_distributed_unavailable
@pytest.mark.flaky(reruns=5)
@RunIf(skip_windows=True)  # unhandled timeouts
def test_default_process_group():
    collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)


@skip_distributed_unavailable
@mock.patch.dict(os.environ, {}, clear=True)
def test_collective_manages_default_group():
    collective = TorchCollective()
    with mock.patch("torch.distributed.init_process_group"):
        collective.setup(main_address="foo", main_port="123")

    assert TorchCollective.manages_default_group

    with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
        "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
    ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
        collective.teardown()
    destroy_mock.assert_called_once_with(mock_group)

    assert not TorchCollective.manages_default_group