# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from torch.nn.parallel.distributed import DistributedDataParallel

from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
from tests_fabric.test_fabric import BoringModel


@pytest.mark.parametrize(
    "accelerator",
    [
        "cpu",
        pytest.param("cuda", marks=RunIf(min_cuda_gpus=2)),
    ],
)
def test_ddp_save_load(accelerator, tmp_path):
    """Test that DDP model checkpoints can be saved and loaded successfully."""
    fabric = Fabric(devices=2, accelerator=accelerator, strategy="ddp_spawn")
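    # launch() spawns the worker processes and runs _run_ddp_save_load on each rank,
    # with this Fabric instance passed as the function's first argument.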
    fabric.launch(_run_ddp_save_load, tmp_path)


def _run_ddp_save_load(fabric, tmp_path):
    fabric.seed_everything(0)
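    # Use rank 0's temporary directory on all ranks so every process saves and loads the same files.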
    tmp_path = fabric.broadcast(tmp_path)

    model = torch.nn.Linear(2, 2)
    params_before = deepcopy(list(model.parameters()))

    # Save
    fabric.save(tmp_path / "saved_before_setup.ckpt", {"model": model})
    wrapped_model = fabric.setup(model)
    fabric.save(tmp_path / "saved_after_setup.ckpt", {"model": wrapped_model})

    def assert_params_equal(params0, params1):
        assert all(torch.equal(p0, p1.to(p0.device)) for p0, p1 in zip(params0, params1))

    # Load
    model = torch.nn.Linear(2, 2)
    fabric.load(tmp_path / "saved_before_setup.ckpt", {"model": model})
    assert_params_equal(params_before, model.parameters())
    fabric.load(tmp_path / "saved_after_setup.ckpt", {"model": model})
    assert_params_equal(params_before, model.parameters())

    wrapped_model = fabric.setup(model)
    fabric.load(tmp_path / "saved_before_setup.ckpt", {"model": wrapped_model})
    assert_params_equal(params_before, wrapped_model.parameters())
    fabric.load(tmp_path / "saved_after_setup.ckpt", {"model": wrapped_model})
    assert_params_equal(params_before, wrapped_model.parameters())


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True)
@mock.patch(
    "lightning.fabric.wrappers.torch.compile",
    Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
)
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
    """Test that Fabric can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
    from torch._dynamo import OptimizedModule

    fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
    fabric.launch()

    model = BoringModel()
    compile_kwargs = {"mode": "reduce-overhead"}
    compiled_model = torch.compile(model, **compile_kwargs)
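    # Clear the mock's call history so the assertion below only sees the re-compilation
    # that happens inside fabric.setup().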
    torch.compile.reset_mock()

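    # With `_reapply_compile=True`, setup() unwraps the compiled module, applies DDP, and then
    # compiles again on top of the DDP wrapper (see the isinstance checks below).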
    fabric_model = fabric.setup(compiled_model, _reapply_compile=True)

    assert isinstance(fabric_model._forward_module, OptimizedModule)
    assert isinstance(fabric_model._forward_module._orig_mod, DistributedDataParallel)
    # Assert we called compile again with the same arguments, but on the DDP-wrapped module
    torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)

    assert fabric_model._original_module == model
    assert fabric_model._forward_module._orig_mod.module == model
    assert fabric_model.device == fabric.device

    # Smoke-testing forward to ensure we don't get compilation errors
    for _ in range(3):
        fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward()


@pytest.mark.parametrize(
    ("clip_type", "accelerator", "precision"),
    [
        ("norm", "cpu", "32-true"),
        ("val", "cpu", "32-true"),
        ("norm", "cpu", "bf16-mixed"),
        ("val", "cpu", "bf16-mixed"),
        pytest.param("norm", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("val", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("norm", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("val", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("norm", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)),
        pytest.param("val", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)),
    ],
)
@RunIf(standalone=True)
def test_clip_gradients(clip_type, accelerator, precision):
    if clip_type == "norm" and precision == "16-mixed":
        pytest.skip(reason="Clipping by norm with 16-mixed is numerically unstable.")

    fabric = Fabric(accelerator=accelerator, devices=2, precision=precision, strategy="ddp")
    fabric.launch()
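    # Reuse the gradient-clipping checks from the single-device tests, now running across
    # two DDP processes.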
    _run_test_clip_gradients(fabric=fabric, clip_type=clip_type)