# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from pathlib import Path
from unittest import mock

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric import Fabric
from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from torch.utils.data import DataLoader, DistributedSampler

from tests_fabric.helpers.datasets import RandomDataset
from tests_fabric.helpers.runif import RunIf


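# A small SwiGLU-style feed-forward block used as the toy model throughout these tests:
# the forward pass computes w3(silu(w1(x)) * w2(x)) with input/output size 32 and hidden size 64.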
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(32, 64)
        self.w2 = nn.Linear(32, 64)
        self.w3 = nn.Linear(64, 32)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


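# Tensor-parallel plan for the FeedForward block: with the default layouts, ColwiseParallel shards the
# output features of w1/w2 across the "tensor_parallel" mesh dimension, while RowwiseParallel shards the
# input features of w3 and reduces its output, so the block returns a replicated result.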
def _parallelize_feed_forward_tp(model, device_mesh):
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

    tp_mesh = device_mesh["tensor_parallel"]
    tp_plan = {
        "w1": ColwiseParallel(),
        "w2": ColwiseParallel(),
        "w3": RowwiseParallel(),
    }
    parallelize_module(model, tp_mesh, tp_plan)
    return model


def _parallelize_feed_forward_fsdp2(model, device_mesh):
    from torch.distributed._composable.fsdp.fully_shard import fully_shard

    dp_mesh = device_mesh["data_parallel"]
    assert dp_mesh.ndim == 1  # Hybrid-sharding not supported

    # Fully-shard each layer
    fully_shard(model.w1, mesh=dp_mesh)
    fully_shard(model.w2, mesh=dp_mesh)
    fully_shard(model.w3, mesh=dp_mesh)

    # TODO: Re-enable activation checkpointing
    # Currently, the state dict keys get prefixed with '_checkpoint_wrapper',
    # which leads to mismatches when loading weights into a checkpoint-wrapped module.
    # PyTorch should handle this automatically.

    # model = checkpoint_wrapper(model)

    return model


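# 2D parallelism: apply the tensor-parallel plan first, then fully shard the (already tensor-parallel)
# layers across the "data_parallel" mesh dimension with FSDP2.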
def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):
    model = _parallelize_feed_forward_tp(model, device_mesh)
    model = _parallelize_feed_forward_fsdp2(model, device_mesh)
    return model


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
def test_setup_device_mesh():
    from torch.distributed.device_mesh import DeviceMesh

    for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)):
        strategy = ModelParallelStrategy(
            parallelize_fn=(lambda m, _: m),
            data_parallel_size=dp_size,
            tensor_parallel_size=tp_size,
        )
        fabric = Fabric(accelerator="auto", devices=4, strategy=strategy)
        fabric.launch()

        device_mesh = fabric.strategy.device_mesh
        assert isinstance(device_mesh, DeviceMesh)
        assert device_mesh.device_type == fabric.device.type
        assert device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel")
        assert device_mesh.size(0) == dp_size
        assert device_mesh.size(1) == tp_size
        assert device_mesh.ndim == 2

        fabric.barrier()

    # Passing "auto" will select internode and intranode dimensions automatically
    strategy = ModelParallelStrategy(
        parallelize_fn=(lambda m, _: m),
        data_parallel_size="auto",
        tensor_parallel_size="auto",
    )
    fabric = Fabric(accelerator="auto", devices=4, num_nodes=1, strategy=strategy)
    fabric.launch()
    assert fabric.strategy.device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel")
    assert fabric.strategy.device_mesh.size(0) == 1
    assert fabric.strategy.device_mesh.size(1) == 4


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
def test_tensor_parallel():
    from torch.distributed._tensor import DTensor

    strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_tp)
    fabric = Fabric(accelerator="auto", devices=2, strategy=strategy)
    fabric.launch()

    fabric.seed_everything(0)

    with fabric.init_module(empty_init=True):
        model = FeedForward()

    model = fabric.setup(model)
    optimizer = torch.optim.AdamW(model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)

    device_mesh = fabric.strategy.device_mesh
    assert all(tensor.device_mesh == device_mesh["tensor_parallel"] for tensor in optimizer.param_groups[0]["params"])
    assert all(isinstance(weight, DTensor) for weight in model.parameters())
    assert model.w1.weight.device_mesh == device_mesh["tensor_parallel"]

    dataset_size = 6
    dataset = RandomDataset(32, dataset_size)
    dataloader = DataLoader(dataset, batch_size=2)
    dataloader = fabric.setup_dataloaders(dataloader)

    # No data sharding, all GPUs get the same input inside a TP group
    assert len(dataloader) == dataset_size // dataloader.batch_size
    assert isinstance(dataloader.sampler, DistributedSampler)

    for _, batch in enumerate(dataloader):
        # All batches must be identical across TP group
        batches = fabric.all_gather(batch)
        assert all(torch.equal(batches[0], batches[i]) for i in range(1, len(batches)))

        output = model(batch)
        fabric.backward(output.sum())
        assert isinstance(model.w1.weight.grad, DTensor)
        assert model.w1.weight.grad.device_mesh == device_mesh["tensor_parallel"]
        optimizer.step()
        optimizer.zero_grad()


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
def test_fsdp2_tensor_parallel():
    from torch.distributed._tensor import DTensor

    strategy = ModelParallelStrategy(
        parallelize_fn=_parallelize_feed_forward_fsdp2_tp,
        data_parallel_size=2,
        tensor_parallel_size=2,
    )
    fabric = Fabric(accelerator="auto", devices=4, strategy=strategy)
    fabric.launch()

    fabric.seed_everything(0)

    with fabric.init_module(empty_init=True):
        model = FeedForward()

    model = fabric.setup(model)
    optimizer = torch.optim.AdamW(model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)

    assert all(isinstance(weight, DTensor) for weight in model.parameters())
    assert all(isinstance(tensor, DTensor) for tensor in optimizer.param_groups[0]["params"])
    assert model.w1.weight.device_mesh.ndim == 2
    assert model.w1.weight.device_mesh.size(0) == 2
    assert model.w1.weight.device_mesh.size(1) == 2
    assert all(weight.device.type != "meta" for weight in model.parameters())
    assert all(tensor.device_mesh.ndim == 2 for tensor in optimizer.param_groups[0]["params"])
    assert all(tensor.device.type != "meta" for tensor in optimizer.param_groups[0]["params"])

    dataset_size = 8
    dataset = RandomDataset(32, dataset_size)
    dataloader = DataLoader(dataset, batch_size=2)
    dataloader = fabric.setup_dataloaders(dataloader)

    # No data sharding across TP dimension, sharding across data-parallel dimension only
    device_mesh = fabric.strategy.device_mesh
    dp_mesh = device_mesh["data_parallel"]
    tp_mesh = device_mesh["tensor_parallel"]
    assert len(dataloader) == dataset_size // dataloader.batch_size // dp_mesh.size()
    assert isinstance(dataloader.sampler, DistributedSampler)

    for _, batch in enumerate(dataloader):
        batches = fabric.all_gather(batch)
        # Batches across the TP dimension must be identical
        batches_tp = batches[tp_mesh.mesh]
        assert all(torch.equal(batches_tp[0], batches_tp[i]) for i in range(1, len(batches_tp)))
        # Batches across the DP dimension must be different
        batches_dp = batches[dp_mesh.mesh]
        assert all(not torch.equal(batches_dp[0], batches_dp[i]) for i in range(1, len(batches_dp)))

        output = model(batch)
        fabric.backward(output.sum())
        assert isinstance(model.w1.weight.grad, DTensor)
        assert model.w1.weight.grad.device_mesh == device_mesh
        optimizer.step()
        optimizer.zero_grad()


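# Shared helper: runs a single forward/backward/optimizer step with the given Fabric.
# If no model/optimizer are passed in, fresh ones are created and set up first.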
def _train(fabric, model=None, optimizer=None):
    fabric.seed_everything(0)

    if model is None:
        with fabric.init_module(empty_init=True):
            model = FeedForward()
        model = fabric.setup(model)
    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters())
        optimizer = fabric.setup_optimizers(optimizer)

    output = model(torch.rand(2, 32, device=fabric.device))
    fabric.backward(output.sum())
    optimizer.step()
    optimizer.zero_grad()
    return model, optimizer


@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
@pytest.mark.parametrize(
    "precision",
    [
        pytest.param("32-true"),
        pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
    ],
)
def test_train_save_load(precision, tmp_path):
    """Test 2D-parallel training, saving, and loading with different precision settings."""
    strategy = ModelParallelStrategy(
        _parallelize_feed_forward_fsdp2_tp,
        data_parallel_size=2,
        tensor_parallel_size=2,
    )
    fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision=precision)
    fabric.launch()
    model, optimizer = _train(fabric)

    checkpoint_path = fabric.broadcast(str(tmp_path / "dist-checkpoint"))

    params_before = [p.full_tensor().clone() for p in model.parameters()]
    state = {"model": model, "optimizer": optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)
    assert set(os.listdir(checkpoint_path)) == {
        ".metadata",
        "__0_0.distcp",
        "__1_0.distcp",
        "__2_0.distcp",
        "__3_0.distcp",
        "meta.pt",
    }

    # re-init all objects and resume
    strategy = ModelParallelStrategy(
        _parallelize_feed_forward_fsdp2_tp,
        data_parallel_size=2,
        tensor_parallel_size=2,
    )
    fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision=precision)
    fabric.launch()
    model, optimizer = _train(fabric)

    # check correctness with loaded state
    state = {"model": model, "optimizer": optimizer, "steps": 0}
    metadata = fabric.load(checkpoint_path, state)
    for p0, p1 in zip(params_before, (p.full_tensor() for p in model.parameters())):
        torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)

    # check user data in state reloaded
    assert state["steps"] == 1
    assert not metadata

    # attempt to load a key not in the metadata checkpoint
    state = {"model": model, "coconut": 11}
    with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
        fabric.load(checkpoint_path, state)

    # `strict=False` ignores the missing key
    state = {"model": model, "coconut": 11}
    fabric.load(checkpoint_path, state, strict=False)
    assert state["coconut"] == 11


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_save_full_state_dict(tmp_path):
    """Test that ModelParallelStrategy saves the full state into a single file with
    `save_distributed_checkpoint=False`."""
    from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict

    strategy = ModelParallelStrategy(
        _parallelize_feed_forward_fsdp2,
        data_parallel_size=2,
        tensor_parallel_size=1,
        save_distributed_checkpoint=False,
    )
    fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
    fabric.launch()
    model, optimizer = _train(fabric)

    checkpoint_path = Path(fabric.broadcast(str(tmp_path / "fsdp-checkpoint.pt")))
    state = {"model": model, "optimizer": optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)

    checkpoint = torch.load(checkpoint_path, weights_only=True)
    assert checkpoint["steps"] == 1
    loaded_state_dict = checkpoint["model"]

    # assert the correct model state was saved
    state_dict = model.state_dict()
    assert set(loaded_state_dict.keys()) == set(state_dict.keys())
    for param_name in state_dict:
        assert torch.equal(loaded_state_dict[param_name], state_dict[param_name].full_tensor().cpu())
    params_before = [p.full_tensor().cpu() for p in model.parameters()]

    # assert the correct optimizer state was saved
    optimizer_state_before = get_optimizer_state_dict(model, optimizer)
    assert set(checkpoint["optimizer"].keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}

    # 1. verify the FSDP state can be loaded back into a FSDP model/strategy directly
    strategy = ModelParallelStrategy(_parallelize_feed_forward_fsdp2, data_parallel_size=2, tensor_parallel_size=1)
    fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
    fabric.launch()
    model, optimizer = _train(fabric)

    metadata = fabric.load(checkpoint_path, {"model": model, "optimizer": optimizer})
    assert metadata == {"steps": 1}

    params_after = [p.full_tensor() for p in model.parameters()]
    assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))

    optimizer_state_after = get_optimizer_state_dict(model, optimizer)
    optimizer_state_after["param_groups"][0]["betas"] = tuple(optimizer_state_after["param_groups"][0]["betas"])
    assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
    torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
    assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]

    # run a step to verify the optimizer state is correct
    _train(fabric, model, optimizer)

    # 2. verify the FSDP state can be loaded back into a single-device model/strategy
    fabric = Fabric(accelerator="cpu", devices=1)
    model, optimizer = _train(fabric)
    metadata = fabric.load(checkpoint_path, {"model": model, "optimizer": optimizer})
    assert metadata == {"steps": 1}
    params_after = list(model.parameters())
    assert all(torch.equal(p0, p1) for p0, p1 in zip(params_before, params_after))

    # get optimizer state after loading
    normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
    fabric.save(normal_checkpoint_path, {"model": model, "optimizer": optimizer, "steps": 2})

    optimizer_state_after = torch.load(normal_checkpoint_path, weights_only=True)["optimizer"]
    assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
    assert torch.equal(
        optimizer_state_after["state"][0]["exp_avg"],
        optimizer_state_before["state"]["_forward_module.w1.weight"]["exp_avg"].full_tensor().cpu(),
    )

    # run a step to verify the optimizer state is correct
    _train(fabric, model, optimizer)

    # 3. verify that the single-device model/strategy state can be loaded into a FSDP model/strategy
    strategy = ModelParallelStrategy(_parallelize_feed_forward_fsdp2, data_parallel_size=2, tensor_parallel_size=1)
    fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
    fabric.launch()
    model, optimizer = _train(fabric)

    metadata = fabric.load(normal_checkpoint_path, {"model": model, "optimizer": optimizer})
    assert metadata == {"steps": 2}

    params_after = [p.full_tensor() for p in model.parameters()]
    assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))

    optimizer_state_after = get_optimizer_state_dict(model, optimizer)
    optimizer_state_after["param_groups"][0]["betas"] = tuple(optimizer_state_after["param_groups"][0]["betas"])
    assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
    torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
    assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]

    # run a step to verify the optimizer state is correct
    _train(fabric, model, optimizer)


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_load_full_state_dict_into_sharded_model(tmp_path):
    """Test that the strategy can load a full-state checkpoint into a distributed model."""
    fabric = Fabric(accelerator="cuda", devices=1)
    fabric.seed_everything(0)
    model, optimizer = _train(fabric)

    # Save a full-state-dict checkpoint
    checkpoint_path = Path(fabric.broadcast(str(tmp_path / "full-checkpoint.pt")))
    state = {"model": model, "optimizer": optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)

    # Gather all weights and store a copy manually
    params_before = torch.cat([p.cpu().view(-1) for p in model.parameters()])

    # Create a FSDP sharded model
    strategy = ModelParallelStrategy(_parallelize_feed_forward_fsdp2, data_parallel_size=2, tensor_parallel_size=1)
    fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
    fabric.launch()
    model, optimizer = _train(fabric)

    state = {"model": model, "optimizer": optimizer, "steps": 44}
    fabric.load(checkpoint_path, state)
    assert state["steps"] == 1

    # Gather all weights and compare
    params_after = torch.cat([p.full_tensor().cpu().view(-1) for p in model.parameters()])
    assert torch.equal(params_before, params_after)

    # Create a raw state-dict checkpoint to test `Fabric.load_raw` too
    raw_checkpoint_path = checkpoint_path.with_name("model-state-dict")
    if fabric.global_rank == 0:
        checkpoint = torch.load(checkpoint_path, weights_only=True)
        torch.save(checkpoint["model"], raw_checkpoint_path)
    fabric.barrier()

    _train(fabric, model, optimizer)
    fabric.load_raw(raw_checkpoint_path, model)

    # Gather all weights and compare
    params_after = torch.cat([p.full_tensor().cpu().view(-1) for p in model.parameters()])
    assert torch.equal(params_before, params_after)


@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("move_to_device", [True, False])
@mock.patch("lightning.fabric.wrappers._FabricModule")
def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
    """Test that `move_to_device` has no effect: the ModelParallelStrategy decides onto which device each
    (sharded) parameter gets moved."""
    from torch.distributed._tensor import DTensor

    strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2)
    fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
    fabric.launch()

    model = FeedForward()
    fabric_model = fabric.setup_module(model, move_to_device=move_to_device)
    fabric_module_mock.assert_not_called()

    # the linear layer got sharded and each part is on the expected device
    assert fabric_model.w1.weight.device == torch.device("cuda", fabric.local_rank)
    assert isinstance(fabric_model.w1.weight, DTensor)

    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for models with pieces on
    # different devices
    assert fabric_model.device == torch.device("cuda", fabric.local_rank)
    assert fabric.device == torch.device("cuda", fabric.local_rank)


@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
        ("16-true", torch.float16),
        pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
    ],
)
def test_module_init_context(precision, expected_dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""
    strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2)
    fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision=precision)
    fabric.launch()

    def _run_setup_assertions(empty_init, expected_device):
        with fabric.init_module(empty_init=empty_init):
            model = FeedForward()

        # The model is on the CPU/meta-device until after `.setup()`
        assert all(weight.device == expected_device for weight in model.parameters())
        assert all(weight.dtype == expected_dtype for weight in model.parameters())
        model = fabric.setup(model)
        # Parameters get sharded in `.setup()` and moved to the target device
        assert all(weight.device == torch.device("cuda", fabric.local_rank) for weight in model.parameters())
        assert all(weight.dtype == expected_dtype for weight in model.parameters())

    _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
    _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))


@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_save_filter(tmp_path):
    strategy = ModelParallelStrategy(
        parallelize_fn=_parallelize_feed_forward_fsdp2,
        save_distributed_checkpoint=False,
    )
    fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
    fabric.launch()
    model = FeedForward()
    model = fabric.setup_module(model)

    tmp_path = Path(fabric.broadcast(str(tmp_path)))
    state = {"model": model}
    filter = {"model": lambda k, v: "bias" in k}

    checkpoint_path = tmp_path / "full.pth"
    fabric.save(checkpoint_path, state, filter=filter)
    checkpoint = torch.load(checkpoint_path, weights_only=True)["model"]
    assert set(checkpoint) == {"w1.bias", "w2.bias", "w3.bias"}
    assert type(checkpoint["w1.bias"]) is torch.Tensor

    fabric.strategy._save_distributed_checkpoint = True
    checkpoint_path = tmp_path / "distributed"
    with pytest.raises(NotImplementedError, match="doesn't support loading distributed filtered"):
        fabric.save(checkpoint_path, state, filter=filter)


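# Parallelization function for a single linear layer: column-wise tensor parallelism over the
# "tensor_parallel" mesh dimension, then FSDP2 sharding over the "data_parallel" dimension.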
def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
    from torch.distributed._composable.fsdp.fully_shard import fully_shard
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

    dp_mesh = device_mesh["data_parallel"]
    tp_mesh = device_mesh["tensor_parallel"]

    parallelize_module(model, tp_mesh, ColwiseParallel())
    fully_shard(model, mesh=dp_mesh)
    return model


@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize(
    "precision",
    [
        "32-true",
        pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
    ],
)
@pytest.mark.parametrize(
    "clip_type",
    [
        pytest.param("norm", marks=pytest.mark.skip("Gradient clipping by norm is not correct.")),
        "val",
    ],
)
def test_clip_gradients(clip_type, precision):
    strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2)
    fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy)
    fabric.launch()

    in_features, out_features = 32, 2
    model = torch.nn.Linear(in_features, out_features, bias=False)
    model.weight.data.fill_(0.01)

    model = fabric.setup(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    optimizer = fabric.setup_optimizers(optimizer)

    batch = torch.full((1, in_features), 0.1, device=fabric.device)
    loss = model(batch).sum()

    # The example is constructed such that the gradients are all the same
    fabric.backward(loss)

    if clip_type == "norm":
        norm = torch.linalg.vector_norm(model.weight.grad.full_tensor().detach().cpu(), 2, dtype=torch.float32).item()
        new_norm = norm / 10
        fabric.clip_gradients(model, optimizer, max_norm=new_norm * 10)
        assert torch.allclose(
            torch.linalg.vector_norm(model.weight.grad.full_tensor().detach().cpu(), 2, dtype=torch.float32),
            torch.tensor(new_norm),
        )
    elif clip_type == "val":
        val = model.weight.grad.full_tensor()[0, 0].item()
        new_val = val / 2.0
        fabric.clip_gradients(model, optimizer, clip_val=new_val)
        assert torch.allclose(
            model.weight.grad.full_tensor(), torch.full_like(model.weight.grad.full_tensor(), new_val)
        )
    else:
        raise AssertionError(f"Unknown clip type: {clip_type}")

    optimizer.step()
    optimizer.zero_grad()


@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
def test_save_sharded_and_consolidate_and_load(tmp_path):
    """Test the consolidation of a distributed (DTensor) checkpoint into a single file."""
    strategy = ModelParallelStrategy(
        _parallelize_feed_forward_fsdp2_tp,
        data_parallel_size=2,
        tensor_parallel_size=2,
    )
    fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()

    model = FeedForward()
    model = fabric.setup(model)
    optimizer = torch.optim.Adam(model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)
    state = {"model": model, "optimizer": optimizer, "steps": 1}

    # run one iteration to init the state of the optimizer
    loss = model(torch.rand(1, 32, device=fabric.device)).sum()
    fabric.backward(loss)
    optimizer.step()

    checkpoint_path_sharded = fabric.broadcast(str(tmp_path / "checkpoint_sharded"))
    fabric.save(checkpoint_path_sharded, state)
    assert set(os.listdir(checkpoint_path_sharded)) == {
        ".metadata",
        "__0_0.distcp",
        "__1_0.distcp",
        "__2_0.distcp",
        "__3_0.distcp",
        "meta.pt",
    }

    # consolidate the checkpoint to a single file
    checkpoint_path_full = fabric.broadcast(str(tmp_path / "checkpoint_full.pt"))
    if fabric.global_rank == 0:
        checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded))
        torch.save(checkpoint, checkpoint_path_full)
    fabric.barrier()

    # re-init and load from full checkpoint
    strategy = ModelParallelStrategy(
        _parallelize_feed_forward_fsdp2_tp,
        data_parallel_size=2,
        tensor_parallel_size=2,
    )
    fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()

    model = FeedForward()
    model = fabric.setup(model)
    optimizer = torch.optim.Adam(model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)
    state = {"model": model, "optimizer": optimizer, "steps": 1}
    fabric.load(checkpoint_path_full, state)


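# `_load_raw_module_state` is exercised below with a plain (non-distributed) state dict loaded into a
# model whose `layer1` has been tensor-parallelized, while the remaining parameters and buffers stay regular.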
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_load_raw_module_state():
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

    class CustomModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.parameter = nn.Parameter(torch.rand(2, 2))
            self.layer1 = nn.Linear(4, 4)
            self.layer2 = nn.Linear(4, 4)
            self.register_buffer("persistent_buffer", torch.rand(2), persistent=True)
            self.register_buffer("non_persistent_buffer", torch.rand(2), persistent=False)

    fabric = Fabric(accelerator="cuda", devices=2)
    fabric.launch()
    fabric.seed_everything(0)

    with fabric.init_module():
        model = CustomModel()

    state_dict = deepcopy(model.state_dict())

    with fabric.init_module():
        model = CustomModel()

    device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp",))
    plan = {"layer1": ColwiseParallel()}
    parallelize_module(model, device_mesh, plan)
    _load_raw_module_state(state_dict, model, strict=True)

    assert torch.equal(model.parameter, state_dict["parameter"])
    assert torch.equal(model.layer1.weight.full_tensor(), state_dict["layer1.weight"])
    assert torch.equal(model.layer2.weight, state_dict["layer2.weight"])
    assert torch.equal(model.persistent_buffer, state_dict["persistent_buffer"])

    state_dict.pop("parameter")
    with pytest.raises(KeyError, match="The model contains a key 'parameter' that does not exist"):
        _load_raw_module_state(state_dict, model, strict=True)

    _load_raw_module_state(state_dict, model, strict=False)