(8/n) Support 2D Parallelism - 2D Parallel Fabric Docs (#19887)
This commit is contained in:
parent
8fc7b4ae94
commit
341474aaac
|
@ -61,15 +61,15 @@ When loading a model from a checkpoint, for example when fine-tuning, set ``empt
|
|||
----
|
||||
|
||||
|
||||
********************************************
|
||||
Model-parallel training (FSDP and DeepSpeed)
|
||||
********************************************
|
||||
***************************************************
|
||||
Model-parallel training (FSDP, TP, DeepSpeed, etc.)
|
||||
***************************************************
|
||||
|
||||
When training sharded models with :doc:`FSDP <model_parallel/fsdp>` or DeepSpeed, using :meth:`~lightning.fabric.fabric.Fabric.init_module` is necessary in most cases because otherwise model initialization gets very slow (minutes) or (and that's more likely) you run out of CPU memory due to the size of the model.
|
||||
When training distributed models with :doc:`FSDP/TP <model_parallel/index>` or DeepSpeed, using :meth:`~lightning.fabric.fabric.Fabric.init_module` is necessary in most cases because otherwise model initialization gets very slow (minutes) or (and that's more likely) you run out of CPU memory due to the size of the model.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Recommended for FSDP and DeepSpeed
|
||||
# Recommended for FSDP, TP and DeepSpeed
|
||||
with fabric.init_module(empty_init=True):
|
||||
model = GPT3() # parameters are placed on the meta-device
|
||||
|
||||
|
@ -81,4 +81,4 @@ When training sharded models with :doc:`FSDP <model_parallel/fsdp>` or DeepSpeed
|
|||
|
||||
.. note::
|
||||
Empty-init is experimental and the behavior may change in the future.
|
||||
For FSDP on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).
|
||||
For distributed models on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).
|
||||
|
|
|
@ -82,7 +82,6 @@ Get started
|
|||
.. displayitem::
|
||||
:header: Pipeline Parallelism
|
||||
:description: Coming soon
|
||||
:button_link:
|
||||
:col_css: col-md-4
|
||||
:height: 180
|
||||
:tag: advanced
|
||||
|
|
|
@ -110,6 +110,7 @@ In Fabric, define a function that applies the tensor parallelism to the model:
|
|||
parallelize_module(model, tp_mesh, plan)
|
||||
return model
|
||||
|
||||
By writing the parallelization code in a separate function rather than hardcoding it into the model, we keep the original source code clean and maintainable.
|
||||
Next, configure the :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy` in Fabric:
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -222,6 +223,8 @@ Beyond this toy example, we recommend you study our `LLM Tensor Parallel Example
|
|||
----
|
||||
|
||||
|
||||
.. _tp-data-loading:
|
||||
|
||||
***************************
|
||||
Data-loading considerations
|
||||
***************************
|
||||
|
|
|
@ -2,4 +2,286 @@
|
|||
2D Parallelism (Tensor Parallelism + FSDP)
|
||||
##########################################
|
||||
|
||||
Content will be available soon.
|
||||
2D Parallelism combines Tensor Parallelism (TP) and Fully Sharded Data Parallelism (FSDP) to leverage the memory efficiency of FSDP and the computational scalability of TP.
|
||||
This hybrid approach balances the trade-offs of each method, optimizing memory usage and minimizing communication overhead, enabling the training of extremely large models on large GPU clusters.
|
||||
|
||||
The :doc:`Tensor Parallelism documentation <tp>` and a general understanding of `FSDP <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`_ are a prerequisite for this tutorial.
|
||||
|
||||
.. note:: This is an experimental feature.
|
||||
|
||||
|
||||
----
|
||||
|
||||
|
||||
*********************
|
||||
Enable 2D parallelism
|
||||
*********************
|
||||
|
||||
We will start off with the same feed forward example model as in the :doc:`Tensor Parallelism tutorial <tp>`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
Next, we define a function that applies the desired parallelism to our model.
|
||||
The function must take as first argument the model and as second argument the a :class:`~torch.distributed.device_mesh.DeviceMesh`.
|
||||
More on how the device mesh works later.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
from torch.distributed._composable.fsdp.fully_shard import fully_shard
|
||||
|
||||
def parallelize_feedforward(model, device_mesh):
|
||||
# Lightning will set up a device mesh for you
|
||||
# Here, it is 2-dimensional
|
||||
tp_mesh = device_mesh["tensor_parallel"]
|
||||
dp_mesh = device_mesh["data_parallel"]
|
||||
|
||||
if tp_mesh.size() > 1:
|
||||
# Use PyTorch's distributed tensor APIs to parallelize the model
|
||||
plan = {
|
||||
"w1": ColwiseParallel(),
|
||||
"w2": RowwiseParallel(),
|
||||
"w3": ColwiseParallel(),
|
||||
}
|
||||
parallelize_module(model, tp_mesh, plan)
|
||||
|
||||
if dp_mesh.size() > 1:
|
||||
# Use PyTorch's FSDP2 APIs to parallelize the model
|
||||
fully_shard(model.w1, mesh=dp_mesh)
|
||||
fully_shard(model.w2, mesh=dp_mesh)
|
||||
fully_shard(model.w3, mesh=dp_mesh)
|
||||
fully_shard(model, mesh=dp_mesh)
|
||||
|
||||
return model
|
||||
|
||||
By writing the parallelization code in a separate function rather than hardcoding it into the model, we keep the original source code clean and maintainable.
|
||||
In addition to the tensor-parallel code from the :doc:`Tensor Parallelism tutorial <tp>`, this function also shards the model's parameters using FSDP along the data-parallel dimension.
|
||||
|
||||
Finally, pass the parallelization function to the :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy` and configure the data-parallel and tensor-parallel sizes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import lightning as L
|
||||
from lightning.fabric.strategies import ModelParallelStrategy
|
||||
|
||||
strategy = ModelParallelStrategy(
|
||||
parallelize_fn=parallelize_feedforward,
|
||||
# Define the size of the 2D parallelism
|
||||
# Set these to "auto" (default) to apply TP intra-node and FSDP inter-node
|
||||
data_parallel_size=2,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
|
||||
fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
|
||||
fabric.launch()
|
||||
|
||||
|
||||
In this example with 4 GPUs, Fabric will create a device mesh that groups GPU 0-1 and GPU 2-3 (2 groups because ``data_parallel_size=2``, and 2 GPUs per group because ``tensor_parallel_size=2``).
|
||||
Later on when ``fabric.setup(model)`` is called, each layer wrapped with FSDP (``fully_shard``) will be split into two shards, one for the GPU 0-1 group, and one for the GPU 2-3 group.
|
||||
Finally, the tensor parallelism will apply to each group, splitting the sharded tensor across the GPUs within each group.
|
||||
|
||||
|
||||
.. collapse:: Full training example (requires at least 4 GPUs).
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
from torch.distributed._composable.fsdp.fully_shard import fully_shard
|
||||
|
||||
import lightning as L
|
||||
from lightning.pytorch.demos.boring_classes import RandomDataset
|
||||
from lightning.fabric.strategies import ModelParallelStrategy
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
def parallelize_feedforward(model, device_mesh):
|
||||
# Lightning will set up a device mesh for you
|
||||
# Here, it is 2-dimensional
|
||||
tp_mesh = device_mesh["tensor_parallel"]
|
||||
dp_mesh = device_mesh["data_parallel"]
|
||||
|
||||
if tp_mesh.size() > 1:
|
||||
# Use PyTorch's distributed tensor APIs to parallelize the model
|
||||
plan = {
|
||||
"w1": ColwiseParallel(),
|
||||
"w2": RowwiseParallel(),
|
||||
"w3": ColwiseParallel(),
|
||||
}
|
||||
parallelize_module(model, tp_mesh, plan)
|
||||
|
||||
if dp_mesh.size() > 1:
|
||||
# Use PyTorch's FSDP2 APIs to parallelize the model
|
||||
fully_shard(model.w1, mesh=dp_mesh)
|
||||
fully_shard(model.w2, mesh=dp_mesh)
|
||||
fully_shard(model.w3, mesh=dp_mesh)
|
||||
fully_shard(model, mesh=dp_mesh)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
strategy = ModelParallelStrategy(
|
||||
parallelize_fn=parallelize_feedforward,
|
||||
data_parallel_size=2,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
|
||||
fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
|
||||
fabric.launch()
|
||||
|
||||
# Initialize the model
|
||||
model = FeedForward(8192, 8192)
|
||||
model = fabric.setup(model)
|
||||
|
||||
# Define the optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=True)
|
||||
optimizer = fabric.setup_optimizers(optimizer)
|
||||
|
||||
# Define dataset/dataloader
|
||||
dataset = RandomDataset(8192, 128)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
|
||||
dataloader = fabric.setup_dataloaders(dataloader)
|
||||
|
||||
# Simplified training loop
|
||||
for i, batch in enumerate(dataloader):
|
||||
output = model(batch)
|
||||
loss = output.sum()
|
||||
fabric.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
fabric.print(f"Iteration {i} complete")
|
||||
|
||||
fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
|
||||
|
||||
|
|
||||
|
||||
Beyond this toy example, we recommend you study our `LLM 2D Parallel Example (Llama 3) <https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel>`_.
|
||||
|
||||
|
||||
----
|
||||
|
||||
|
||||
*******************
|
||||
Effective use cases
|
||||
*******************
|
||||
|
||||
In the toy example above, the parallelization is configured to work within a single machine across multiple GPUs.
|
||||
However, in practice the main use case for 2D parallelism is in multi-node training, where one can effectively combine both methods to maximize throughput and model scale.
|
||||
Since tensor-parallelism requires blocking collective calls, fast GPU data transfers are essential to keep throughput high and therefore TP is typically applied across GPUs within a machine.
|
||||
On the other hand, FSDP by design has the advantage that it can overlap GPU transfers with the computation (it can prefetch layers).
|
||||
Hence, combining FSDP for inter-node parallelism and TP for intra-node parallelism is generally a good strategy to minimize both the latency and network bandwidth usage, making it possible to scale to much larger models than is possible with FSDP alone.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from lightning.fabric.strategies import ModelParallelStrategy
|
||||
|
||||
strategy = ModelParallelStrategy(
|
||||
# Default is "auto"
|
||||
# Applies TP intra-node and DP inter-node
|
||||
data_parallel_size="auto",
|
||||
tensor_parallel_size="auto",
|
||||
)
|
||||
|
||||
|
||||
----
|
||||
|
||||
|
||||
***************************
|
||||
Data-loading considerations
|
||||
***************************
|
||||
|
||||
In a tensor-parallelized model, it is important that the model receives an identical input on each GPU that participates in the same tensor-parallel group.
|
||||
However, across the data-parallel dimension, the inputs should be different.
|
||||
In other words, if TP is applied within a node, and FSDP across nodes, each node must receive a different batch, but every GPU within the node gets the same batch of data.
|
||||
|
||||
If you use a PyTorch data loader and set it up using :meth:`~lightning.fabric.fabric.Fabric.setup_dataloaders`, Fabric will automatically handle this for you by configuring the distributed sampler.
|
||||
However, when you shuffle data in your dataset or data loader, or when applying randomized transformations/augmentations in your data, you must still ensure that the seed is set appropriately.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import lightning as L
|
||||
|
||||
fabric = L.Fabric(...)
|
||||
|
||||
# Define dataset/dataloader
|
||||
# If there is randomness/augmentation in the dataset, fix the seed
|
||||
dataset = MyDataset(seed=42)
|
||||
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
|
||||
|
||||
# Fabric configures the sampler automatically for you such that
|
||||
# all batches in a tensor-parallel group are identical,
|
||||
# while still sharding the dataset across the data-parallel group
|
||||
dataloader = fabric.setup_dataloaders(dataloader)
|
||||
|
||||
for i, batch in enumerate(dataloader):
|
||||
...
|
||||
|
||||
|
||||
|
||||
|
||||
----
|
||||
|
||||
|
||||
**********
|
||||
Next steps
|
||||
**********
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="display-card-container">
|
||||
<div class="row">
|
||||
|
||||
.. displayitem::
|
||||
:header: LLM 2D Parallel Example
|
||||
:description: Full example how to combine TP + FSDP in a large language model (Llama 3)
|
||||
:col_css: col-md-4
|
||||
:button_link: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel
|
||||
:height: 160
|
||||
:tag: advanced
|
||||
|
||||
.. displayitem::
|
||||
:header: Pipeline Parallelism
|
||||
:description: Coming sooon
|
||||
:col_css: col-md-4
|
||||
:height: 160
|
||||
:tag: advanced
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
|
||||
|
|
Loading…
Reference in New Issue