# lightning/tests/tests_lite/utilities/test_distributed.py


from functools import partial

import pytest
import torch
from tests_lite.helpers.runif import RunIf

from lightning_lite.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning_lite.plugins.environments import LightningEnvironment
from lightning_lite.strategies import DDPSpawnStrategy
from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning_lite.utilities.distributed import _gather_all_tensors


def wrap_launch_function(fn, strategy, *args, **kwargs):
    # the launcher does not manage this automatically. explanation available in:
    # https://github.com/Lightning-AI/lightning/pull/14926#discussion_r982976718
    strategy.setup_environment()
    return fn(*args, **kwargs)


def spawn_launch(fn, parallel_devices):
    """Copied from ``tests_pytorch.core.test_results.spawn_launch``."""
    # TODO: the accelerator and cluster_environment should be optional to just launch processes, but this requires
    # lazy initialization to be implemented
    device_to_accelerator = {"cuda": CUDAAccelerator, "mps": MPSAccelerator, "cpu": CPUAccelerator}
    accelerator_cls = device_to_accelerator[parallel_devices[0].type]
    strategy = DDPSpawnStrategy(
        accelerator=accelerator_cls(), parallel_devices=parallel_devices, cluster_environment=LightningEnvironment()
    )
    launcher = _MultiProcessingLauncher(strategy=strategy)
    wrapped = partial(wrap_launch_function, fn, strategy)
    return launcher.launch(wrapped, strategy)
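

# A minimal usage sketch (illustrative, not part of the original suite): the launcher
# calls the wrapped function with the strategy as its argument, so a launched
# function takes the strategy as its only parameter. ``_sanity_check`` below is a
# hypothetical name.
#
#     def _sanity_check(strategy):
#         assert strategy.num_processes == len(strategy.parallel_devices)
#
#     spawn_launch(_sanity_check, [torch.device("cpu")] * 2)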


def _test_all_gather_uneven_tensors(strategy):
    rank = strategy.local_rank
    device = strategy.root_device
    world_size = strategy.num_processes
    tensor = torch.ones(rank, device=device)
    result = _gather_all_tensors(tensor)
    assert len(result) == world_size
    for idx in range(world_size):
        assert len(result[idx]) == idx
        assert (result[idx] == torch.ones_like(result[idx])).all()
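

# Note (illustrative): with the two-process parametrization below,
# ``_test_all_gather_uneven_tensors`` has rank 0 contribute an empty tensor and
# rank 1 contribute ``torch.ones(1)``, so every rank should receive
# ``[tensor([]), tensor([1.])]`` back.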


def _test_all_gather_uneven_tensors_multidim(strategy):
    rank = strategy.local_rank
    device = strategy.root_device
    world_size = strategy.num_processes
    tensor = torch.ones(rank + 1, 2 - rank, device=device)
    result = _gather_all_tensors(tensor)
    assert len(result) == world_size
    for idx in range(world_size):
        val = result[idx]
        assert val.shape == (idx + 1, 2 - idx)
        assert (val == torch.ones_like(val)).all()
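

# Note (illustrative): in ``_test_all_gather_uneven_tensors_multidim`` the shapes
# disagree in every dimension (with two processes, rank 0 sends a (1, 2) tensor and
# rank 1 a (2, 1) tensor). This exercises the uneven-shape handling of
# ``_gather_all_tensors``, which -- as an assumption about the implementation, not
# something these tests assert -- pads each tensor to the largest gathered shape,
# all-gathers, and then slices the results back to their original sizes.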


@RunIf(min_torch="1.10", skip_windows=True)
@pytest.mark.parametrize(
    "process",
    [
        _test_all_gather_uneven_tensors_multidim,
        _test_all_gather_uneven_tensors,
    ],
)
@pytest.mark.parametrize(
    "devices",
    [
        pytest.param([torch.device("cuda:0"), torch.device("cuda:1")], marks=RunIf(min_cuda_gpus=2)),
        [torch.device("cpu")] * 2,
    ],
)
def test_gather_all_tensors(devices, process):
    spawn_launch(process, devices)
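

# To run just this module locally, a standard pytest invocation from the repository
# root should work (path taken from this file's location):
#     python -m pytest tests/tests_lite/utilities/test_distributed.py -v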