import torch
from model import Transformer
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


# Taken and modified from torchtitan
# https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
    """Apply parallelisms and activation checkpointing to the model.

    NOTE: The passed-in model should preferably be on the meta device.
    Otherwise, the model must fit in GPU or CPU memory.
    """
    dp_mesh = device_mesh["data_parallel"]
    tp_mesh = device_mesh["tensor_parallel"]

    if tp_mesh.size() > 1:
        # 1. Parallelize the first embedding and the last linear proj layer
        # 2. Parallelize the root norm layer over the sequence dim
        # 3. Shard the first transformer block's inputs

        # Parallelize the first embedding and the last linear out projection
        plan = {
            "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                # Optional: Shard the output along the class dimension to compute the loss in parallel.
                # See `loss_parallel` in `train.py`
                output_layouts=Shard(-1),
                use_local_output=False,
            ),
            "norm": SequenceParallel(),
            "layers.0": PrepareModuleInput(
                input_layouts=(Replicate(), None),
                desired_input_layouts=(Shard(1), None),
                use_local_output=True,
            ),
        }
        model = parallelize_module(model, tp_mesh, plan)

        # Parallelize each transformer block
        for transformer_block in model.layers.values():
            plan = {
                "attention": PrepareModuleInput(
                    input_layouts=(Shard(1), None),
                    desired_input_layouts=(Replicate(), None),
                ),
                "attention.wq": ColwiseParallel(),
                "attention.wk": ColwiseParallel(),
                "attention.wv": ColwiseParallel(),
                "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
                "attention_norm": SequenceParallel(),
                "feed_forward": PrepareModuleInput(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                ),
                "feed_forward.w1": ColwiseParallel(),
                "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
                "feed_forward.w3": ColwiseParallel(),
                "ffn_norm": SequenceParallel(),
            }

            # Adjust the attention module to use the local number of heads
            attn_layer = transformer_block.attention
            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

            # Apply the plan for the current transformer block
            parallelize_module(transformer_block, tp_mesh, plan)

    if dp_mesh.size() > 1:
        assert dp_mesh.ndim == 1  # Hybrid-sharding not supported

        # NOTE: Currently, the user is required to manually handle precision settings such as the `mp_policy` here
        # because the model parallel strategy does not respect all settings of `Fabric(precision=...)` at the moment.
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

        for layer_id, transformer_block in model.layers.items():
            # Apply activation checkpointing
            transformer_block = checkpoint_wrapper(transformer_block)
            # As an optimization, do not reshard after forward for the last
            # transformer block since FSDP would prefetch it immediately
            reshard_after_forward = int(layer_id) < len(model.layers) - 1
            fully_shard(
                transformer_block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )
            model.layers[layer_id] = transformer_block

        model = fully_shard(model, **fsdp_config)

    return model
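

# Example usage (a minimal sketch, not part of the original torchtitan code): build a
# 2D device mesh whose dimension names match the `device_mesh[...]` lookups above and
# pass a meta-device model through `parallelize`. The `Transformer(config)` construction
# and the weight (re)initialization after `to_empty` are repo-specific assumptions.
#
# Launched with e.g. `torchrun --nproc_per_node=4 <this_script>.py`:
#
#     import torch.distributed as dist
#     from torch.distributed.device_mesh import init_device_mesh
#
#     dist.init_process_group("nccl")
#     device_mesh = init_device_mesh(
#         "cuda",
#         (2, 2),  # 2-way data parallel x 2-way tensor parallel
#         mesh_dim_names=("data_parallel", "tensor_parallel"),
#     )
#     with torch.device("meta"):
#         model = Transformer(config)  # hypothetical config object
#     model = parallelize(model, device_mesh)
#     model.to_empty(device="cuda")  # materialize sharded params, then (re)initialize weights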