# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest
import torch
from torch.nn.parallel import DistributedDataParallel

from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0

from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal


@pytest.mark.parametrize(
    ["process_group_backend", "device_str", "expected_process_group_backend"],
    [
        pytest.param("foo", "cpu", "foo"),
        pytest.param("foo", "cuda:0", "foo"),
        pytest.param(None, "cuda:0", "nccl"),
        pytest.param(None, "cpu", "gloo"),
    ],
)
def test_ddp_process_group_backend(process_group_backend, device_str, expected_process_group_backend):
    """Test settings for process group backend."""

    class MockDDPStrategy(DDPStrategy):
        def __init__(self, root_device, process_group_backend):
            self._root_device = root_device
            super().__init__(process_group_backend=process_group_backend)

        @property
        def root_device(self):
            return self._root_device

    strategy = MockDDPStrategy(process_group_backend=process_group_backend, root_device=torch.device(device_str))
    assert strategy._get_process_group_backend() == expected_process_group_backend


def test_ddp_no_backward_sync():
    """Test that the backward sync control calls `.no_sync()`, and only on a DDP-wrapped module."""
    strategy = DDPStrategy()
    assert isinstance(strategy._backward_sync_control, _DDPBackwardSyncControl)

    with pytest.raises(
        TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
    ), strategy._backward_sync_control.no_backward_sync(Mock()):
        pass

    module = MagicMock(spec=DistributedDataParallel)
    with strategy._backward_sync_control.no_backward_sync(module):
        pass

    module.no_sync.assert_called_once()


@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
def test_ddp_extra_kwargs(ddp_mock):
    """Test that additional kwargs passed to the DDPStrategy get passed down to the DistributedDataParallel
    wrapper."""
    module = torch.nn.Linear(1, 1)
    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])
    strategy.setup_module(module)
    ddp_mock.assert_called_with(module=module, device_ids=None)

    ddp_mock.reset_mock()

    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")], find_unused_parameters=True)
    strategy.setup_module(module)
    ddp_mock.assert_called_with(module=module, device_ids=None, find_unused_parameters=True)


def test_ddp_module_state_dict():
    """Test that the module state dict gets retrieved without the prefixed wrapper keys from DDP."""

    class DistributedDataParallelMock(MagicMock):
        def __instancecheck__(self, instance):
            # to make the strategy's `isinstance(model, DistributedDataParallel)` pass with a mock as class
            return True

    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])

    # Without DDP applied (no setup call)
    original_module = torch.nn.Linear(2, 3)
    assert strategy.get_module_state_dict(original_module).keys() == original_module.state_dict().keys()

    # With DDP applied (setup called)
    with mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel", DistributedDataParallelMock):
        wrapped_module = strategy.setup_module(original_module)
        assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()


@pytest.mark.parametrize(
    "clip_type,accelerator,precision",
    [
        ("norm", "cpu", "32-true"),
        ("val", "cpu", "32-true"),
        ("norm", "cpu", "bf16-mixed"),
        ("val", "cpu", "bf16-mixed"),
        pytest.param("norm", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("val", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("norm", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("val", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("norm", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)),
        pytest.param("val", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)),
    ],
)
@RunIf(standalone=True)
def test_ddp_grad_clipping(clip_type, accelerator, precision):
    clipping_test_cls = _MyFabricGradNorm if clip_type == "norm" else _MyFabricGradVal
    fabric = clipping_test_cls(accelerator=accelerator, devices=2, precision=precision, strategy="ddp")
    fabric.run()


@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize(
    "precision,expected_dtype",
    [
        (Precision(), torch.float32),
        (HalfPrecision("16-true"), torch.float16),
        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)),
        (DoublePrecision(), torch.float64),
    ],
)
@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
def test_module_init_context(precision, expected_dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""
    parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
    expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")

    strategy = DDPStrategy(
        parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
    )
    assert strategy.local_rank == 1
    with strategy.module_init_context():
        module = torch.nn.Linear(2, 2)
    assert module.weight.device == module.bias.device == expected_device
    assert module.weight.dtype == module.bias.dtype == expected_dtype


@mock.patch("torch.distributed.init_process_group")
def test_set_timeout(init_process_group_mock):
    """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
    test_timedelta = timedelta(seconds=30)
    strategy = DDPStrategy(timeout=test_timedelta, parallel_devices=[torch.device("cpu")])
    strategy.cluster_environment = LightningEnvironment()
    strategy.accelerator = Mock()
    strategy.setup_environment()
    process_group_backend = strategy._get_process_group_backend()
    global_rank = strategy.cluster_environment.global_rank()
    world_size = strategy.cluster_environment.world_size()
    init_process_group_mock.assert_called_with(
        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
    )