import torch
from model import Transformer
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


# Taken and modified from torchtitan
# https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
    """Apply parallelisms and activation checkpointing to the model.

    NOTE: The model passed in should preferably be on the meta device. Otherwise,
    the model must fit into GPU or CPU memory.
    """
    dp_mesh = device_mesh["data_parallel"]
    tp_mesh = device_mesh["tensor_parallel"]

    if tp_mesh.size() > 1:
        # 1. Parallelize the first embedding and the last linear proj layer
        # 2. Parallelize the root norm layer over the sequence dim
        # 3. Shard the first transformer block's inputs

        # Parallelize the first embedding and the last linear out projection
        plan = {
            "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                # Optional: Shard the output along the class dimension to compute the loss in parallel.
                # See `loss_parallel` in `train.py`
                output_layouts=Shard(-1),
                use_local_output=False,
            ),
            "norm": SequenceParallel(),
            "layers.0": PrepareModuleInput(
                input_layouts=(Replicate(), None),
                desired_input_layouts=(Shard(1), None),
                use_local_output=True,
            ),
        }
        model = parallelize_module(model, tp_mesh, plan)

        # Parallelize each transformer block
        for transformer_block in model.layers.values():
            plan = {
                "attention": PrepareModuleInput(
                    input_layouts=(Shard(1), None),
                    desired_input_layouts=(Replicate(), None),
                ),
                "attention.wq": ColwiseParallel(),
                "attention.wk": ColwiseParallel(),
                "attention.wv": ColwiseParallel(),
                "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
                "attention_norm": SequenceParallel(),
                "feed_forward": PrepareModuleInput(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                ),
                "feed_forward.w1": ColwiseParallel(),
                "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
                "feed_forward.w3": ColwiseParallel(),
                "ffn_norm": SequenceParallel(),
            }

            # Adjust the attention module to use the local number of heads
            attn_layer = transformer_block.attention
            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

            # Apply the plan for the current transformer block
            parallelize_module(transformer_block, tp_mesh, plan)

    if dp_mesh.size() > 1:
        assert dp_mesh.ndim == 1  # Hybrid sharding is not supported

        # NOTE: Currently, the user is required to manually handle precision settings such as the `mp_policy` here
        # because the model parallel strategy does not respect all settings of `Fabric(precision=...)` at the moment.
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
        for layer_id, transformer_block in model.layers.items():
            # Apply activation checkpointing
            transformer_block = checkpoint_wrapper(transformer_block)
            # As an optimization, do not reshard after the forward pass of the last
            # transformer block, since FSDP would prefetch it again immediately
            reshard_after_forward = int(layer_id) < len(model.layers) - 1
            fully_shard(
                transformer_block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )
            model.layers[layer_id] = transformer_block
        model = fully_shard(model, **fsdp_config)

    return model
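

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original example): one way `parallelize` could
# be driven. It builds a 2-D device mesh whose dimension names match the ones
# looked up above ("data_parallel", "tensor_parallel"), applies the plan to a
# model constructed on the meta device, and then materializes the shards.
# The mesh shape, the `torchrun` launch, and the helper name `_example_setup`
# are assumptions about the surrounding training script, not part of this file.
def _example_setup(meta_model: Transformer, dp_size: int = 2, tp_size: int = 4) -> Transformer:
    # Launch with e.g. `torchrun --nproc_per_node=8 ...` so that
    # world_size == dp_size * tp_size.
    from torch.distributed.device_mesh import init_device_mesh

    device_mesh = init_device_mesh(
        "cuda", (dp_size, tp_size), mesh_dim_names=("data_parallel", "tensor_parallel")
    )

    # `meta_model` is expected to have been created under `with torch.device("meta"):`
    # so that no real parameter memory is allocated before sharding.
    model = parallelize(meta_model, device_mesh)

    # Allocate the (now sharded) parameters on the GPU. They are uninitialized at
    # this point; the training script is expected to run the model's own
    # weight-init routine afterwards (model-specific, omitted here).
    model.to_empty(device="cuda")
    return model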