# lightning/tests/tests_fabric/plugins/collectives/test_torch_collective.py

import contextlib
import datetime
import os
from functools import partial
from unittest import mock

import pytest
import torch
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator
from lightning.fabric.plugins.collectives import TorchCollective
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies.ddp import DDPStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from tests_fabric.helpers.runif import RunIf

if TorchCollective.is_available():
    from torch.distributed import ReduceOp
else:
    ReduceOp = mock.Mock()

skip_distributed_unavailable = pytest.mark.skipif(
    not TorchCollective.is_available(), reason="torch.distributed unavailable"
)

# Sentinel mocks passed through the patched `torch.distributed` calls; only their identity matters.
PASSED_TENSOR = mock.Mock()
PASSED_OBJECT = mock.Mock()
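# Context manager used by the multi-process tests below: it asserts that every group created
# through `TorchCollective.new_group` during the block was matched by a `destroy_group` call,
# and that no process group (including the default one) is left registered afterwards.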
@contextlib.contextmanager
def check_destroy_group():
    with mock.patch(
        "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group",
        wraps=TorchCollective.new_group,
    ) as mock_new, mock.patch(
        "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group",
        wraps=TorchCollective.destroy_group,
    ) as mock_destroy:
        yield
    # 0 to account for tests that mock distributed
    # -1 to account for destroying the default process group
    expected = 0 if mock_new.call_count == 0 else mock_destroy.call_count - 1
    assert mock_new.call_count == expected, f"new_group={mock_new.call_count}, destroy_group={mock_destroy.call_count}"
    if TorchCollective.is_available():
        assert not torch.distributed.distributed_c10d._pg_map
        assert not TorchCollective.is_initialized()
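# Each parametrization maps a `TorchCollective` method to the kwargs it should forward unchanged
# to the `torch.distributed` function of the same name; `return_key` names the kwarg the wrapper
# is expected to return (None for calls without a return value).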
@pytest.mark.parametrize(
    ("fn_name", "kwargs", "return_key"),
    [
        ("send", {"tensor": PASSED_TENSOR, "dst": 0, "tag": 0}, None),
        ("recv", {"tensor": PASSED_TENSOR, "src": 0, "tag": 0}, "tensor"),
        ("broadcast", {"tensor": PASSED_TENSOR, "src": 0}, "tensor"),
        ("all_reduce", {"tensor": PASSED_TENSOR, "op": ReduceOp.SUM}, "tensor"),
        ("reduce", {"tensor": PASSED_TENSOR, "dst": 0, "op": ReduceOp.SUM}, "tensor"),
        ("all_gather", {"tensor_list": [PASSED_TENSOR], "tensor": PASSED_TENSOR}, "tensor_list"),
        ("gather", {"tensor": PASSED_TENSOR, "gather_list": [PASSED_TENSOR], "dst": 0}, "gather_list"),
        ("scatter", {"tensor": PASSED_TENSOR, "scatter_list": [PASSED_TENSOR], "src": 0}, "tensor"),
        ("reduce_scatter", {"output": PASSED_TENSOR, "input_list": [PASSED_TENSOR], "op": ReduceOp.SUM}, "output"),
        (
            "all_to_all",
            {"output_tensor_list": [PASSED_TENSOR], "input_tensor_list": [PASSED_TENSOR]},
            "output_tensor_list",
        ),
        ("barrier", {"device_ids": [0]}, None),
        ("all_gather_object", {"object_list": [PASSED_OBJECT], "obj": PASSED_OBJECT}, "object_list"),
        (
            "broadcast_object_list",
            {"object_list": [PASSED_OBJECT], "src": 0, "device": torch.device("cpu")},
            "object_list",
        ),
        (
            "gather_object",
            {"obj": PASSED_OBJECT, "object_gather_list": [PASSED_OBJECT], "dst": 0},
            "object_gather_list",
        ),
        (
            "scatter_object_list",
            {"scatter_object_output_list": [PASSED_OBJECT], "scatter_object_input_list": [PASSED_OBJECT], "src": 0},
            "scatter_object_output_list",
        ),
        ("monitored_barrier", {"timeout": datetime.timedelta(seconds=1), "wait_all_ranks": False}, None),
    ],
)
@skip_distributed_unavailable
def test_collective_calls_with_created_group(fn_name, kwargs, return_key):
    collective = TorchCollective()
    with mock.patch("torch.distributed.init_process_group"):
        collective.setup()
    with mock.patch("torch.distributed.new_group"):
        collective.create_group()
    fn = getattr(collective, fn_name)
    with mock.patch(f"torch.distributed.{fn_name}", autospec=True) as mock_call:
        result = fn(**kwargs)
    mock_call.assert_called_once_with(**kwargs, group=collective.group)
    if return_key is not None:
        assert result == kwargs[return_key]
    with mock.patch("torch.distributed.destroy_process_group"):
        collective.teardown()
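# `_convert_to_native_op` should accept string names in any casing as well as native `ReduceOp`
# (and, on torch >= 1.13, `RedOpType`) values, returning the corresponding native op and raising
# a `ValueError` for anything else.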
@skip_distributed_unavailable
def test_convert_ops():
    # Test regular names
    assert TorchCollective._convert_to_native_op("band") == ReduceOp.BAND
    assert TorchCollective._convert_to_native_op("bor") == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op("bxor") == ReduceOp.BXOR
    assert TorchCollective._convert_to_native_op("max") == ReduceOp.MAX
    assert TorchCollective._convert_to_native_op("min") == ReduceOp.MIN
    assert TorchCollective._convert_to_native_op("product") == ReduceOp.PRODUCT
    assert TorchCollective._convert_to_native_op("sum") == ReduceOp.SUM
    # Test we are passing through native ops without change
    assert TorchCollective._convert_to_native_op(ReduceOp.BAND) == ReduceOp.BAND
    assert TorchCollective._convert_to_native_op(ReduceOp.BOR) == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op(ReduceOp.BXOR) == ReduceOp.BXOR
    assert TorchCollective._convert_to_native_op(ReduceOp.MAX) == ReduceOp.MAX
    assert TorchCollective._convert_to_native_op(ReduceOp.MIN) == ReduceOp.MIN
    assert TorchCollective._convert_to_native_op(ReduceOp.PRODUCT) == ReduceOp.PRODUCT
    assert TorchCollective._convert_to_native_op(ReduceOp.SUM) == ReduceOp.SUM
    # Test we are handling different casing properly
    assert TorchCollective._convert_to_native_op("BOR") == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op("BoR") == ReduceOp.BOR
    assert TorchCollective._convert_to_native_op("avg") == ReduceOp.AVG

    # Test invalid type
    with pytest.raises(ValueError, match="Unsupported op 1 of type int"):
        TorchCollective._convert_to_native_op(1)

    # Test invalid string
    with pytest.raises(ValueError, match="op 'INVALID' is not a member of `Red"):
        TorchCollective._convert_to_native_op("invalid")

    # Test RedOpType
    if _TORCH_GREATER_EQUAL_1_13:
        assert TorchCollective._convert_to_native_op(ReduceOp.RedOpType.AVG) == ReduceOp.RedOpType.AVG
        op = torch.distributed._make_nccl_premul_sum(2.0)  # this returns a ReduceOp
        assert TorchCollective._convert_to_native_op(op) == ReduceOp.PREMUL_SUM
        assert TorchCollective._convert_to_native_op("premul_sum") == ReduceOp.PREMUL_SUM
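# A collective owns at most one group at a time: creating twice or tearing down twice raises,
# teardown destroys both the owned group and the default group it manages, and `os.environ` is
# left untouched throughout. `create_group()` returns the collective, so calls can be chained.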
@skip_distributed_unavailable
@mock.patch.dict(os.environ, {}, clear=True)
def test_repeated_create_and_destroy():
    collective = TorchCollective()
    with mock.patch("torch.distributed.init_process_group"):
        collective.setup(main_address="foo", main_port="123")
    assert not os.environ

    with mock.patch("torch.distributed.new_group") as new_mock:
        collective.create_group()
    new_mock.assert_called_once()

    with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
        collective.create_group()

    with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
        "torch.distributed.destroy_process_group"
    ) as destroy_mock:
        collective.teardown()
    # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
    # group
    destroy_mock.assert_called_once()
    assert not os.environ

    with pytest.raises(RuntimeError, match="TorchCollective` does not own a group"):
        collective.teardown()
    destroy_mock.assert_called_once_with(new_mock.return_value)
    assert collective._group is None

    with mock.patch("torch.distributed.new_group"), mock.patch("torch.distributed.destroy_process_group"):
        # check we can create_group again. also chaining
        collective.create_group().teardown()
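# Helpers for the real multi-process tests below: `collective_launch` spawns one process per device
# through `_MultiProcessingLauncher`/`DDPStrategy` and hands `num_groups` fresh `TorchCollective`s
# to `fn`; `wrap_launch_function` does the per-process setup and wraps the test in
# `check_destroy_group`.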
def collective_launch(fn, parallel_devices, num_groups=1):
    device_to_accelerator = {"cuda": CUDAAccelerator, "cpu": CPUAccelerator}
    accelerator_cls = device_to_accelerator[parallel_devices[0].type]
    strategy = DDPStrategy(
        accelerator=accelerator_cls(),
        parallel_devices=parallel_devices,
        cluster_environment=LightningEnvironment(),
        start_method="spawn",
    )
    launcher = _MultiProcessingLauncher(strategy=strategy)
    collectives = [TorchCollective() for _ in range(num_groups)]
    wrapped = partial(wrap_launch_function, fn, strategy, collectives)
    return launcher.launch(wrapped, strategy, *collectives)


def wrap_launch_function(fn, strategy, collectives, *args, **kwargs):
    strategy._set_world_ranks()
    collectives[0].setup(  # only one needs to setup
        world_size=strategy.num_processes,
        main_address="localhost",
        backend=strategy._get_process_group_backend(),
        rank=strategy.global_rank,
    )
    with check_destroy_group():  # manually use the fixture for the assertions
        fn(*args, **kwargs)
        # not necessary since they will be destroyed on process destruction, only added to fulfill the assertions
        for c in collectives:
            c.teardown()
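# Basic all_gather/reduce/all_reduce round-trip; every rank contributes values derived from its
# global rank, so the expected results are deterministic.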
def _test_distributed_collectives_fn(strategy, collective):
    collective.create_group()

    # all_gather
    tensor_list = [torch.zeros(2, dtype=torch.long) for _ in range(strategy.num_processes)]
    this = torch.arange(2, dtype=torch.long) + 2 * strategy.global_rank
    out = collective.all_gather(tensor_list, this)
    expected = torch.arange(2 * strategy.num_processes).split(2)
    torch.testing.assert_close(tuple(out), expected)

    # reduce
    this = torch.tensor(strategy.global_rank + 1)
    out = collective.reduce(this, dst=0, op="max")
    expected = torch.tensor(strategy.num_processes) if strategy.global_rank == 0 else this
    torch.testing.assert_close(out, expected)

    # all_reduce
    this = torch.tensor(strategy.global_rank + 1)
    out = collective.all_reduce(this, op=ReduceOp.MIN)
    expected = torch.tensor(1)
    torch.testing.assert_close(out, expected)


@pytest.mark.skip(reason="test hangs too often")
@skip_distributed_unavailable
@pytest.mark.parametrize(
    "n", [1, pytest.param(2, marks=[RunIf(skip_windows=True), pytest.mark.xfail(raises=TimeoutError, strict=False)])]
)
def test_collectives_distributed(n):
    collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n)
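# NCCL premul_sum reduction: each input is scaled by the factor (2.0) before summing, so on a
# single GPU the reduced value is 1.5 * 2.0 == 3.0.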
def _test_distributed_collectives_cuda_fn(strategy, collective):
    collective.create_group()

    this = torch.tensor(1.5, device=strategy.root_device)
    premul_sum = torch.distributed._make_nccl_premul_sum(2.0)
    out = collective.all_reduce(this, op=premul_sum)
    assert out == 3


@skip_distributed_unavailable
@RunIf(min_cuda_gpus=1, min_torch="1.13")
def test_collectives_distributed_cuda():
    collective_launch(_test_distributed_collectives_cuda_fn, [torch.device("cuda")])
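# Two overlapping groups over three ranks: {0, 1} and {1, 2}. Rank 1 participates in both, so its
# left all_reduce result (0 + 1 == 1) feeds into the right all_reduce together with rank 2's value
# (1 + 2 == 3). The barrier keeps the right group from starting its reduction while rank 1 is
# still busy in the left one.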
def _test_two_groups(strategy, left_collective, right_collective):
    left_collective.create_group(ranks=[0, 1])
    right_collective.create_group(ranks=[1, 2])

    tensor = torch.tensor(strategy.global_rank)
    if strategy.global_rank in (0, 1):
        tensor = left_collective.all_reduce(tensor)
        assert tensor == 1
    right_collective.barrier()  # avoids deadlock for global rank 1
    if strategy.global_rank in (1, 2):
        tensor = right_collective.all_reduce(tensor)
        assert tensor == 3


@skip_distributed_unavailable
@pytest.mark.flaky(reruns=5)
@RunIf(skip_windows=True)  # unhandled timeouts
@pytest.mark.xfail(raises=TimeoutError, strict=False)
def test_two_groups():
    collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)
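# With no explicit ranks, every collective falls back to the default (WORLD) group created during
# setup; summing `world_size` across `world_size` ranks therefore yields `world_size ** 2`.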
def _test_default_process_group(strategy, *collectives):
    for collective in collectives:
        assert collective.group == torch.distributed.group.WORLD
    world_size = strategy.world_size
    for c in collectives:
        tensor = torch.tensor(world_size)
        r = c.all_reduce(tensor)
        assert world_size**2 == r


@skip_distributed_unavailable
@pytest.mark.flaky(reruns=5)
@RunIf(skip_windows=True)  # unhandled timeouts
def test_default_process_group():
    collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
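# The collective that initializes the default process group is responsible for it:
# `manages_default_group` is set after `setup()` and cleared once `teardown()` has destroyed
# the group.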
@skip_distributed_unavailable
@mock.patch.dict(os.environ, {}, clear=True)
def test_collective_manages_default_group():
    collective = TorchCollective()
    with mock.patch("torch.distributed.init_process_group"):
        collective.setup(main_address="foo", main_port="123")
    assert TorchCollective.manages_default_group

    with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
        "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
    ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
        collective.teardown()
    destroy_mock.assert_called_once_with(mock_group)

    assert not TorchCollective.manages_default_group