# lightning/tests/tests_fabric/utilities/test_distributed.py

import functools
import os
from functools import partial
from pathlib import Path
from unittest import mock
from unittest.mock import Mock

import lightning.fabric
import pytest
import torch
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import (
    _destroy_dist_connection,
    _gather_all_tensors,
    _InfiniteBarrier,
    _init_dist_connection,
    _is_dtensor,
    _set_num_threads_if_needed,
    _suggested_max_num_threads,
    _sync_ddp,
    is_shared_filesystem,
)
from lightning_utilities.core.imports import RequirementCache

from tests_fabric.helpers.runif import RunIf


def wrap_launch_function(fn, strategy, *args, **kwargs):
    # the launcher does not manage this automatically. explanation available in:
    # https://github.com/Lightning-AI/lightning/pull/14926#discussion_r982976718
    strategy.setup_environment()
    return fn(*args, **kwargs)


def spawn_launch(fn, parallel_devices):
    """Copied from ``tests_pytorch.core.test_results.spawn_launch``."""
    # TODO: the accelerator and cluster_environment should be optional to just launch processes,
    # but this requires lazy initialization to be implemented
    device_to_accelerator = {"cuda": CUDAAccelerator, "mps": MPSAccelerator, "cpu": CPUAccelerator}
    accelerator_cls = device_to_accelerator[parallel_devices[0].type]
    strategy = DDPStrategy(
        accelerator=accelerator_cls(),
        parallel_devices=parallel_devices,
        cluster_environment=LightningEnvironment(),
        start_method="spawn",
    )
    launcher = _MultiProcessingLauncher(strategy=strategy)
    wrapped = partial(wrap_launch_function, fn, strategy)
    return launcher.launch(wrapped, strategy)
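
# ``spawn_launch`` starts one worker process per entry in ``parallel_devices`` and waits for all of
# them to finish; each worker function receives its rank's fully set up ``DDPStrategy``. For example
# (mirroring the tests below), ``spawn_launch(_test_all_reduce, [torch.device("cpu"), torch.device("cpu")])``
# runs the reduction checks on two CPU ranks.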


def _test_all_gather_uneven_tensors(strategy):
    rank = strategy.local_rank
    device = strategy.root_device
    world_size = strategy.num_processes
    tensor = torch.ones(rank, device=device)
    result = _gather_all_tensors(tensor)
    assert len(result) == world_size
    for idx in range(world_size):
        assert len(result[idx]) == idx
        assert (result[idx] == torch.ones_like(result[idx])).all()
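
# Rank ``i`` contributes a 1D tensor of length ``i`` (rank 0 contributes an empty tensor), so the
# gathered list must contain one tensor per rank with exactly those uneven lengths, all filled with ones.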


def _test_all_gather_uneven_tensors_multidim(strategy):
    rank = strategy.local_rank
    device = strategy.root_device
    world_size = strategy.num_processes
    tensor = torch.ones(rank + 1, 2 - rank, device=device)
    result = _gather_all_tensors(tensor)
    assert len(result) == world_size
    for idx in range(world_size):
        val = result[idx]
        assert val.shape == (idx + 1, 2 - idx)
        assert (val == torch.ones_like(val)).all()
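
# With two ranks, rank 0 sends a (1, 2) tensor and rank 1 a (2, 1) tensor: the number of dimensions
# matches but every dimension differs across ranks, exercising gathering of tensors with mismatched shapes.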


def _test_all_reduce(strategy):
    rank = strategy.local_rank
    device = strategy.root_device
    world_size = strategy.num_processes  # the collective tests in this module always launch 2 processes
    for dtype in (torch.long, torch.int, torch.float, torch.half):
        # max
        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
        expected = torch.tensor(2, device=device, dtype=dtype)
        result = _sync_ddp(tensor, reduce_op="max")
        assert torch.equal(result, expected)
        assert result is tensor  # inplace

        # sum
        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
        expected = torch.tensor(sum(range(1, world_size + 1)), device=device, dtype=dtype)
        result = _sync_ddp(tensor, reduce_op="sum")
        assert torch.equal(result, expected)
        assert result is tensor  # inplace

        # average
        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
        expected = torch.tensor(sum(range(1, world_size + 1)) / 2, device=device, dtype=dtype)
        result = _sync_ddp(tensor, reduce_op="avg")
        assert torch.equal(result, expected)
        assert result is tensor  # inplace
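
# With the two ranks contributing 1 and 2, the expected results are max=2, sum=3, and avg=sum/2
# (cast to the tensor's dtype). The ``result is tensor`` asserts also check that ``_sync_ddp``
# reduces in place and returns the very tensor it was given.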


@RunIf(skip_windows=True)
@pytest.mark.parametrize(
    "process",
    [
        _test_all_gather_uneven_tensors_multidim,
        _test_all_gather_uneven_tensors,
        _test_all_reduce,
    ],
)
@pytest.mark.parametrize(
    "devices",
    [
        pytest.param([torch.device("cuda:0"), torch.device("cuda:1")], marks=RunIf(min_cuda_gpus=2)),
        [torch.device("cpu"), torch.device("cpu")],
    ],
)
def test_collective_operations(devices, process):
    spawn_launch(process, devices)


@pytest.mark.skipif(
    RequirementCache("torch<2.4") and RequirementCache("numpy>=2.0"),
    reason="torch.distributed not compatible with numpy>=2.0",
)
@pytest.mark.flaky(reruns=3)  # flaky with "process 0 terminated with signal SIGABRT" (GLOO)
def test_is_shared_filesystem(tmp_path, monkeypatch):
    # In the non-distributed case, every location is interpreted as 'shared'
    assert is_shared_filesystem(SingleDeviceStrategy(torch.device("cpu")))

    test_fn = functools.partial(_test_is_shared_filesystem, tmp_path=tmp_path, monkeypatch=monkeypatch)
    spawn_launch(test_fn, [torch.device("cpu"), torch.device("cpu")])
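
# The single-device assertion above covers the trivial shortcut (everything counts as shared);
# the two spawned CPU ranks exercise the distributed checks in ``_test_is_shared_filesystem`` below.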


def _test_is_shared_filesystem(strategy, tmp_path, monkeypatch):
    # Path doesn't exist
    with pytest.raises(FileNotFoundError, match="Unable to determine if the path belongs to a shared filesystem"):
        is_shared_filesystem(strategy, path="not/exist")

    # Path exists but is not the same on all ranks
    file = tmp_path / f"file-rank-{strategy.global_rank}"
    file.touch()
    folder = tmp_path / f"folder-rank-{strategy.global_rank}"
    folder.mkdir()
    assert not is_shared_filesystem(strategy, path=file)
    assert not is_shared_filesystem(strategy, path=folder)

    # Path exists and is the same on all ranks
    folder = tmp_path / "folder"
    file = folder / "file"
    if strategy.global_rank == 0:
        folder.mkdir()
        file.touch()
    strategy.barrier()
    assert folder.exists()
    assert is_shared_filesystem(strategy, path=folder)
    assert is_shared_filesystem(strategy, path=file)
    assert os.listdir(folder) == ["file"]  # the per-rank test files created by the check got cleaned up

    # Path defaults to CWD
    monkeypatch.chdir(tmp_path)
    assert Path.cwd() == tmp_path
    assert is_shared_filesystem(strategy)
    monkeypatch.undo()

    # Path is a symlink
    linked = Path(tmp_path / "linked")
    if strategy.global_rank == 0:
        linked.symlink_to(tmp_path / "folder", target_is_directory=True)
    strategy.barrier()
    assert is_shared_filesystem(strategy, path=linked)

    # Remote paths are considered shared
    assert is_shared_filesystem(strategy, path="s3://my-bucket/data")
@pytest.mark.parametrize("invalid", [-1, 0])
def test_suggested_max_num_threads(invalid):
with pytest.raises(ValueError, match="should be >= 1"):
_suggested_max_num_threads(invalid)


@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("lightning.fabric.utilities.distributed.torch.set_num_threads")
@mock.patch("lightning.fabric.utilities.distributed._num_cpus_available", return_value=4)
@pytest.mark.parametrize(("num_processes", "expected"), [(1, 4), (2, 2), (3, 1), (4, 1), (8, 1)])
def test_set_num_threads_if_needed(_, set_num_threads_mock, num_processes, expected):
    assert "OMP_NUM_THREADS" not in os.environ
    _set_num_threads_if_needed(num_processes)
    set_num_threads_mock.assert_called_with(expected)
    assert os.environ["OMP_NUM_THREADS"] == str(expected)

    # if the env variable is already set, no change
    set_num_threads_mock.reset_mock()
    _set_num_threads_if_needed(1)
    set_num_threads_mock.assert_not_called()
    assert os.environ["OMP_NUM_THREADS"] == str(expected)
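
# With 4 CPUs mocked, the parametrized expectations are consistent with dividing the cores evenly
# across processes and flooring at one thread, i.e. ``max(1, 4 // num_processes)``: 4 for one
# process, 2 for two, and 1 once there are at least as many processes as cores.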


def test_infinite_barrier():
    # distributed not available
    barrier = _InfiniteBarrier()
    assert barrier.group is None
    with mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=False):
        barrier.__enter__()
        assert barrier.group is None
        barrier()
        barrier.__exit__(None, None, None)
        assert barrier.group is None

    # distributed available
    barrier = _InfiniteBarrier()
    with mock.patch(
        "lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True
    ), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock:
        barrier.__enter__()
        dist_mock.new_group.assert_called_once()
        assert barrier.barrier == barrier.group.monitored_barrier
        assert barrier.barrier.call_count == 0
        barrier()
        assert barrier.barrier.call_count == 1
        barrier.__exit__(None, None, None)
        assert barrier.barrier.call_count == 2
        dist_mock.destroy_process_group.assert_called_once()
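
# When a process group is available, ``__enter__`` creates a dedicated group and binds the barrier
# to that group's ``monitored_barrier``; ``__exit__`` issues one final barrier call before destroying
# the group, which is why the call count ends at 2 after a single explicit ``barrier()``.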
@mock.patch("lightning.fabric.utilities.distributed.atexit")
@mock.patch("lightning.fabric.utilities.distributed.torch.distributed.init_process_group")
def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
_init_dist_connection(LightningEnvironment(), "nccl")
atexit_mock.register.assert_called_once_with(_destroy_dist_connection)
atexit_mock.reset_mock()
_init_dist_connection(LightningEnvironment(), "gloo")
atexit_mock.register.assert_not_called()
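
# The ``atexit`` teardown hook is registered only on the first initialization; calling
# ``_init_dist_connection`` again with another backend must not register a second handler.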


@RunIf(min_torch="2.4")
def test_is_dtensor(monkeypatch):
    from torch.distributed._tensor import DTensor

    assert _is_dtensor(Mock(spec=DTensor))
    assert not _is_dtensor(torch.zeros(2, 2))

    monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
    assert not _is_dtensor(Mock(spec=DTensor))
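
# ``Mock(spec=DTensor)`` passes ``isinstance`` checks against ``DTensor``, so ``_is_dtensor`` accepts
# it while the version flag is set; once ``_TORCH_GREATER_EQUAL_2_4`` is patched to ``False``, the
# helper is expected to report ``False`` even for a DTensor-like object.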