diff --git a/docs/source-fabric/advanced/model_init.rst b/docs/source-fabric/advanced/model_init.rst index f1e5cf846b..4b31df036f 100644 --- a/docs/source-fabric/advanced/model_init.rst +++ b/docs/source-fabric/advanced/model_init.rst @@ -61,15 +61,15 @@ When loading a model from a checkpoint, for example when fine-tuning, set ``empt ---- -******************************************** -Model-parallel training (FSDP and DeepSpeed) -******************************************** +*************************************************** +Model-parallel training (FSDP, TP, DeepSpeed, etc.) +*************************************************** -When training sharded models with :doc:`FSDP ` or DeepSpeed, using :meth:`~lightning.fabric.fabric.Fabric.init_module` is necessary in most cases because otherwise model initialization gets very slow (minutes) or (and that's more likely) you run out of CPU memory due to the size of the model. +When training distributed models with :doc:`FSDP/TP ` or DeepSpeed, using :meth:`~lightning.fabric.fabric.Fabric.init_module` is necessary in most cases because otherwise model initialization gets very slow (minutes) or (and that's more likely) you run out of CPU memory due to the size of the model. .. code-block:: python - # Recommended for FSDP and DeepSpeed + # Recommended for FSDP, TP and DeepSpeed with fabric.init_module(empty_init=True): model = GPT3() # parameters are placed on the meta-device @@ -81,4 +81,4 @@ When training sharded models with :doc:`FSDP ` or DeepSpeed .. note:: Empty-init is experimental and the behavior may change in the future. - For FSDP on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too). + For distributed models on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too). diff --git a/docs/source-fabric/advanced/model_parallel/index.rst b/docs/source-fabric/advanced/model_parallel/index.rst index 59b8acc72d..78649279a2 100644 --- a/docs/source-fabric/advanced/model_parallel/index.rst +++ b/docs/source-fabric/advanced/model_parallel/index.rst @@ -82,7 +82,6 @@ Get started .. displayitem:: :header: Pipeline Parallelism :description: Coming soon - :button_link: :col_css: col-md-4 :height: 180 :tag: advanced diff --git a/docs/source-fabric/advanced/model_parallel/tp.rst b/docs/source-fabric/advanced/model_parallel/tp.rst index 4d5fb26181..fdb05da121 100644 --- a/docs/source-fabric/advanced/model_parallel/tp.rst +++ b/docs/source-fabric/advanced/model_parallel/tp.rst @@ -110,6 +110,7 @@ In Fabric, define a function that applies the tensor parallelism to the model: parallelize_module(model, tp_mesh, plan) return model +By writing the parallelization code in a separate function rather than hardcoding it into the model, we keep the original source code clean and maintainable. Next, configure the :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy` in Fabric: .. code-block:: python @@ -222,6 +223,8 @@ Beyond this toy example, we recommend you study our `LLM Tensor Parallel Example ---- +.. _tp-data-loading: + *************************** Data-loading considerations *************************** diff --git a/docs/source-fabric/advanced/model_parallel/tp_fsdp.rst b/docs/source-fabric/advanced/model_parallel/tp_fsdp.rst index 606d9619f6..7bab5e88a3 100644 --- a/docs/source-fabric/advanced/model_parallel/tp_fsdp.rst +++ b/docs/source-fabric/advanced/model_parallel/tp_fsdp.rst @@ -2,4 +2,286 @@ 2D Parallelism (Tensor Parallelism + FSDP) ########################################## -Content will be available soon. +2D Parallelism combines Tensor Parallelism (TP) and Fully Sharded Data Parallelism (FSDP) to leverage the memory efficiency of FSDP and the computational scalability of TP. +This hybrid approach balances the trade-offs of each method, optimizing memory usage and minimizing communication overhead, enabling the training of extremely large models on large GPU clusters. + +The :doc:`Tensor Parallelism documentation ` and a general understanding of `FSDP `_ are a prerequisite for this tutorial. + +.. note:: This is an experimental feature. + + +---- + + +********************* +Enable 2D parallelism +********************* + +We will start off with the same feed forward example model as in the :doc:`Tensor Parallelism tutorial `. + +.. code-block:: python + + import torch + import torch.nn as nn + import torch.nn.functional as F + + + class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + +Next, we define a function that applies the desired parallelism to our model. +The function must take as first argument the model and as second argument the a :class:`~torch.distributed.device_mesh.DeviceMesh`. +More on how the device mesh works later. + +.. code-block:: python + + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + from torch.distributed.tensor.parallel import parallelize_module + from torch.distributed._composable.fsdp.fully_shard import fully_shard + + def parallelize_feedforward(model, device_mesh): + # Lightning will set up a device mesh for you + # Here, it is 2-dimensional + tp_mesh = device_mesh["tensor_parallel"] + dp_mesh = device_mesh["data_parallel"] + + if tp_mesh.size() > 1: + # Use PyTorch's distributed tensor APIs to parallelize the model + plan = { + "w1": ColwiseParallel(), + "w2": RowwiseParallel(), + "w3": ColwiseParallel(), + } + parallelize_module(model, tp_mesh, plan) + + if dp_mesh.size() > 1: + # Use PyTorch's FSDP2 APIs to parallelize the model + fully_shard(model.w1, mesh=dp_mesh) + fully_shard(model.w2, mesh=dp_mesh) + fully_shard(model.w3, mesh=dp_mesh) + fully_shard(model, mesh=dp_mesh) + + return model + +By writing the parallelization code in a separate function rather than hardcoding it into the model, we keep the original source code clean and maintainable. +In addition to the tensor-parallel code from the :doc:`Tensor Parallelism tutorial `, this function also shards the model's parameters using FSDP along the data-parallel dimension. + +Finally, pass the parallelization function to the :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy` and configure the data-parallel and tensor-parallel sizes: + +.. code-block:: python + + import lightning as L + from lightning.fabric.strategies import ModelParallelStrategy + + strategy = ModelParallelStrategy( + parallelize_fn=parallelize_feedforward, + # Define the size of the 2D parallelism + # Set these to "auto" (default) to apply TP intra-node and FSDP inter-node + data_parallel_size=2, + tensor_parallel_size=2, + ) + + fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy) + fabric.launch() + + +In this example with 4 GPUs, Fabric will create a device mesh that groups GPU 0-1 and GPU 2-3 (2 groups because ``data_parallel_size=2``, and 2 GPUs per group because ``tensor_parallel_size=2``). +Later on when ``fabric.setup(model)`` is called, each layer wrapped with FSDP (``fully_shard``) will be split into two shards, one for the GPU 0-1 group, and one for the GPU 2-3 group. +Finally, the tensor parallelism will apply to each group, splitting the sharded tensor across the GPUs within each group. + + +.. collapse:: Full training example (requires at least 4 GPUs). + + .. code-block:: python + + import torch + import torch.nn as nn + import torch.nn.functional as F + + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + from torch.distributed.tensor.parallel import parallelize_module + from torch.distributed._composable.fsdp.fully_shard import fully_shard + + import lightning as L + from lightning.pytorch.demos.boring_classes import RandomDataset + from lightning.fabric.strategies import ModelParallelStrategy + + + class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + + def parallelize_feedforward(model, device_mesh): + # Lightning will set up a device mesh for you + # Here, it is 2-dimensional + tp_mesh = device_mesh["tensor_parallel"] + dp_mesh = device_mesh["data_parallel"] + + if tp_mesh.size() > 1: + # Use PyTorch's distributed tensor APIs to parallelize the model + plan = { + "w1": ColwiseParallel(), + "w2": RowwiseParallel(), + "w3": ColwiseParallel(), + } + parallelize_module(model, tp_mesh, plan) + + if dp_mesh.size() > 1: + # Use PyTorch's FSDP2 APIs to parallelize the model + fully_shard(model.w1, mesh=dp_mesh) + fully_shard(model.w2, mesh=dp_mesh) + fully_shard(model.w3, mesh=dp_mesh) + fully_shard(model, mesh=dp_mesh) + + return model + + + strategy = ModelParallelStrategy( + parallelize_fn=parallelize_feedforward, + data_parallel_size=2, + tensor_parallel_size=2, + ) + + fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy) + fabric.launch() + + # Initialize the model + model = FeedForward(8192, 8192) + model = fabric.setup(model) + + # Define the optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=True) + optimizer = fabric.setup_optimizers(optimizer) + + # Define dataset/dataloader + dataset = RandomDataset(8192, 128) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=8) + dataloader = fabric.setup_dataloaders(dataloader) + + # Simplified training loop + for i, batch in enumerate(dataloader): + output = model(batch) + loss = output.sum() + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + fabric.print(f"Iteration {i} complete") + + fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + +| + +Beyond this toy example, we recommend you study our `LLM 2D Parallel Example (Llama 3) `_. + + +---- + + +******************* +Effective use cases +******************* + +In the toy example above, the parallelization is configured to work within a single machine across multiple GPUs. +However, in practice the main use case for 2D parallelism is in multi-node training, where one can effectively combine both methods to maximize throughput and model scale. +Since tensor-parallelism requires blocking collective calls, fast GPU data transfers are essential to keep throughput high and therefore TP is typically applied across GPUs within a machine. +On the other hand, FSDP by design has the advantage that it can overlap GPU transfers with the computation (it can prefetch layers). +Hence, combining FSDP for inter-node parallelism and TP for intra-node parallelism is generally a good strategy to minimize both the latency and network bandwidth usage, making it possible to scale to much larger models than is possible with FSDP alone. + + +.. code-block:: python + + from lightning.fabric.strategies import ModelParallelStrategy + + strategy = ModelParallelStrategy( + # Default is "auto" + # Applies TP intra-node and DP inter-node + data_parallel_size="auto", + tensor_parallel_size="auto", + ) + + +---- + + +*************************** +Data-loading considerations +*************************** + +In a tensor-parallelized model, it is important that the model receives an identical input on each GPU that participates in the same tensor-parallel group. +However, across the data-parallel dimension, the inputs should be different. +In other words, if TP is applied within a node, and FSDP across nodes, each node must receive a different batch, but every GPU within the node gets the same batch of data. + +If you use a PyTorch data loader and set it up using :meth:`~lightning.fabric.fabric.Fabric.setup_dataloaders`, Fabric will automatically handle this for you by configuring the distributed sampler. +However, when you shuffle data in your dataset or data loader, or when applying randomized transformations/augmentations in your data, you must still ensure that the seed is set appropriately. + + +.. code-block:: python + + import lightning as L + + fabric = L.Fabric(...) + + # Define dataset/dataloader + # If there is randomness/augmentation in the dataset, fix the seed + dataset = MyDataset(seed=42) + dataloader = DataLoader(dataset, batch_size=8, shuffle=True) + + # Fabric configures the sampler automatically for you such that + # all batches in a tensor-parallel group are identical, + # while still sharding the dataset across the data-parallel group + dataloader = fabric.setup_dataloaders(dataloader) + + for i, batch in enumerate(dataloader): + ... + + + + +---- + + +********** +Next steps +********** + +.. raw:: html + +
+
+ +.. displayitem:: + :header: LLM 2D Parallel Example + :description: Full example how to combine TP + FSDP in a large language model (Llama 3) + :col_css: col-md-4 + :button_link: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel + :height: 160 + :tag: advanced + +.. displayitem:: + :header: Pipeline Parallelism + :description: Coming sooon + :col_css: col-md-4 + :height: 160 + :tag: advanced + + +.. raw:: html + +
+
+ +|