##################
Tensor Parallelism
##################
Tensor parallelism is a technique for training large models by splitting the weights of individual layers across multiple devices.
Each device stores and computes only a shard of every parallelized layer, which reduces the memory needed per device.
The trade-off is extra communication between devices in every forward and backward pass, so for smaller models the communication overhead may outweigh the benefits.
This method is most effective for models with very large layers, where it can significantly improve performance and memory efficiency.

.. raw:: html

    <a target="_blank" href="https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-lightning-fabric">
      <img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open In Studio" style="width: auto; max-width: none;"/>
    </a>

----

*******************************************
How to exploit parallelism in linear layers
*******************************************

In tensor parallelism, the computation of a linear layer can be split up across GPUs.
This saves memory because each GPU only needs to hold a portion of the weight matrix.
There are two ways a linear layer can be split up: row-wise or column-wise.


Column-wise Parallel
====================

In a column-wise parallel layer, the weight matrix is split evenly along the column dimension.
Each GPU is sent the same input, and computes a regular matrix multiplication with its portion of the weight matrix.
At the end, the outputs from each GPU can be concatenated to form the final output.

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric/tp-colwise.jpeg
   :alt: Left: Regular matrix multiplication. Right: Column-wise parallel matrix multiplication split across two GPUs.
   :width: 100%
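
The column-wise split can be checked numerically on a single device. The following is a minimal sketch that assumes only CPU PyTorch: the two "GPUs" are simulated by slicing a tensor with ``torch.chunk``, and the matrix sizes are arbitrary illustrative values. The weight ``W`` uses the ``(in_features, out_features)`` layout of the figure above.

.. code-block:: python

    import torch

    torch.manual_seed(0)

    # Illustrative sizes: 4 input rows, 6 input features, 8 output features
    X = torch.randn(4, 6)
    W = torch.randn(6, 8)  # (in_features, out_features), as in the figure

    # Reference: regular matrix multiplication on a single device
    Y = X @ W

    # Column-wise split across two simulated GPUs: each holds 4 of the 8 columns
    W_shards = torch.chunk(W, 2, dim=1)

    # Every GPU receives the same full input and computes a slice of the output columns
    partial_outputs = [X @ W_i for W_i in W_shards]

    # Concatenating the slices along the column dimension recovers the full output
    Y_colwise = torch.cat(partial_outputs, dim=1)

    print(torch.allclose(Y, Y_colwise, atol=1e-5))
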

Row-wise Parallel
=================

Row-wise parallelism divides the rows of the weight matrix evenly across devices.
In addition, the input gets split the same way along the inner dimension (because the weight matrix now has fewer rows).
Each GPU then performs a regular matrix multiplication with its portion of the weight matrix and inputs.
At the end, the outputs from each GPU can be summed up element-wise (all-reduce) to form the final output.

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric/tp-rowwise.jpeg
   :alt: Left: Regular matrix multiplication. Right: Row-wise parallel matrix multiplication split across two GPUs.
   :width: 100%
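
The row-wise case can be simulated the same way. This is again a minimal CPU-only sketch with illustrative sizes: the weight is split along its rows, the input along its inner (column) dimension, and the element-wise sum of the partial results plays the role of the all-reduce.

.. code-block:: python

    import torch

    torch.manual_seed(0)

    X = torch.randn(4, 6)
    W = torch.randn(6, 8)  # (in_features, out_features), as in the figure

    # Reference: regular matrix multiplication on a single device
    Y = X @ W

    # Row-wise split: each simulated GPU holds 3 of the 6 rows of W ...
    W_shards = torch.chunk(W, 2, dim=0)
    # ... and the matching 3 of the 6 columns (inner dimension) of the input
    X_shards = torch.chunk(X, 2, dim=1)

    # Each GPU produces a full-sized but partial result
    partial_outputs = [X_i @ W_i for X_i, W_i in zip(X_shards, W_shards)]

    # Summing the partial results element-wise (what an all-reduce computes) gives the full output
    Y_rowwise = torch.stack(partial_outputs).sum(dim=0)

    print(torch.allclose(Y, Y_rowwise, atol=1e-5))
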

Combined Column- and Row-wise Parallel
======================================

When there are multiple linear layers in sequence, for example in an MLP or a Transformer, the column-wise and row-wise parallel styles can be combined for maximum effect.
Instead of concatenating the output of the column-wise parallel layer, we keep the outputs separate and feed them directly to the row-wise parallel layer.
This way, we avoid costly data transfers between GPUs: the intermediate activations never need to be gathered, and the only communication left in the forward pass is the all-reduce after the row-wise parallel layer.

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric/tp-combined.jpeg
   :alt: Top: Two regular matrix multiplications in sequence. Bottom: Combined column-wise and row-wise parallel matrix multiplications across two GPUs.
   :width: 100%

Activation functions between the layers can still be applied without additional communication because they are element-wise; they are not shown in the figures for simplicity.
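
The combined scheme, including an element-wise activation between the two layers, can be sketched on one CPU device as well. The sizes are illustrative and ``F.silu`` is an arbitrary choice of element-wise activation; each simulated GPU runs both layers on its own shards, and only the final sum corresponds to communication.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    X = torch.randn(4, 6)
    A = torch.randn(6, 10)  # first layer, split column-wise
    B = torch.randn(10, 6)  # second layer, split row-wise

    # Reference: two matrix multiplications with an element-wise activation in between
    Y = F.silu(X @ A) @ B

    A_shards = torch.chunk(A, 2, dim=1)  # columns of A
    B_shards = torch.chunk(B, 2, dim=0)  # matching rows of B

    # Each simulated GPU: full input -> its columns of A -> activation -> its rows of B.
    # Nothing is exchanged between the two layers.
    partials = [F.silu(X @ A_i) @ B_i for A_i, B_i in zip(A_shards, B_shards)]

    # A single all-reduce (here: a plain sum) at the end produces the full output
    Y_tp = sum(partials)

    print(torch.allclose(Y, Y_tp, atol=1e-4))
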

----

***********************************
Apply tensor parallelism to a model
***********************************

To apply tensor parallelism to a model with Fabric, you need a good understanding of your model's architecture to decide where to apply the parallel styles you've seen above.
Let's start with a simple MLP toy example:


.. code-block:: python

    import torch.nn as nn
    import torch.nn.functional as F


    class FeedForward(nn.Module):
        def __init__(self, dim, hidden_dim):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x):
            return self.w2(F.silu(self.w1(x)) * self.w3(x))

This model has three linear layers. The outputs of ``w1`` (after the SiLU activation) and ``w3`` are multiplied element-wise, and the product is then fed into layer ``w2``.
Therefore, ``w1`` and ``w3`` are suitable candidates for column-wise parallelism: when both are split the same way, each GPU can compute its slice of the element-wise product locally, and that slice feeds directly into ``w2`` as a row-wise parallel layer.
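
To connect the figures to the ``nn.Linear`` layers in ``FeedForward``, the sketch below simulates this plan on a single CPU device by slicing the weights by hand. It illustrates the math only, not how PyTorch implements ``ColwiseParallel`` or ``RowwiseParallel`` internally; the helper ``sharded_forward`` and the small sizes are hypothetical. Note that ``nn.Linear`` stores its weight as ``(out_features, in_features)`` and computes ``x @ weight.T``, so the column-wise split of the figures corresponds to slicing ``weight`` along dim 0, and the row-wise split to slicing it along dim 1.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class FeedForward(nn.Module):
        def __init__(self, dim, hidden_dim):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x):
            return self.w2(F.silu(self.w1(x)) * self.w3(x))


    def sharded_forward(model, x, world_size):
        """Hypothetical helper: simulate the TP plan on one device by slicing weights."""
        # nn.Linear weights have shape (out_features, in_features)
        w1_shards = model.w1.weight.chunk(world_size, dim=0)  # column-wise: split outputs
        w3_shards = model.w3.weight.chunk(world_size, dim=0)  # column-wise: split outputs
        w2_shards = model.w2.weight.chunk(world_size, dim=1)  # row-wise: split inputs
        partials = []
        for w1, w3, w2 in zip(w1_shards, w3_shards, w2_shards):
            # Each simulated rank sees the full input and computes a slice of the hidden features
            hidden = F.silu(F.linear(x, w1)) * F.linear(x, w3)
            # That slice feeds straight into the matching slice of w2
            partials.append(F.linear(hidden, w2))
        # Summing the partial outputs is what the all-reduce after w2 computes
        return torch.stack(partials).sum(dim=0)


    torch.manual_seed(0)
    model = FeedForward(dim=16, hidden_dim=32)
    x = torch.randn(4, 16)
    with torch.no_grad():
        reference = model(x)
        simulated = sharded_forward(model, x, world_size=2)
    print(torch.allclose(reference, simulated, atol=1e-5))

Because ``w1`` and ``w3`` are sharded identically, each rank holds matching slices of both outputs, so the SiLU and the element-wise product need no communication; only the final sum does.
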

To use it with Fabric, define a function that applies tensor parallelism to the model:

.. code-block:: python

    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
    from torch.distributed.tensor.parallel import parallelize_module


    def parallelize_feedforward(model, device_mesh):
        # Lightning will set up a device mesh for you
        tp_mesh = device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(model, tp_mesh, plan)
        return model

By writing the parallelization code in a separate function rather than hardcoding it into the model, we keep the original source code clean and maintainable.
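
Keeping the plan outside the model also makes it straightforward to generate for larger models. The keys of the plan are module FQNs (fully qualified names) relative to the module passed to ``parallelize_module``. The sketch below assumes a hypothetical ``Model`` that stacks several ``FeedForward`` blocks and a hypothetical helper ``build_tp_plan`` that emits one entry per linear layer, such as ``blocks.0.w1``. So that it runs without a distributed setup, it passes placeholder factories instead of ``ColwiseParallel`` and ``RowwiseParallel``.

.. code-block:: python

    import torch.nn as nn
    import torch.nn.functional as F


    class FeedForward(nn.Module):
        def __init__(self, dim, hidden_dim):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x):
            return self.w2(F.silu(self.w1(x)) * self.w3(x))


    class Model(nn.Module):
        """Hypothetical model that stacks several FeedForward blocks."""

        def __init__(self, dim, hidden_dim, num_blocks):
            super().__init__()
            self.blocks = nn.ModuleList(FeedForward(dim, hidden_dim) for _ in range(num_blocks))

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x


    def build_tp_plan(model, colwise_style, rowwise_style):
        """Hypothetical helper: map the FQN of every FeedForward linear layer to a new style object."""
        plan = {}
        for name, module in model.named_modules():
            if isinstance(module, FeedForward):
                prefix = f"{name}." if name else ""  # the root module has an empty name
                plan[prefix + "w1"] = colwise_style()
                plan[prefix + "w3"] = colwise_style()
                plan[prefix + "w2"] = rowwise_style()
        return plan


    model = Model(dim=16, hidden_dim=32, num_blocks=2)
    # Placeholder factories stand in for ColwiseParallel / RowwiseParallel in this sketch
    plan = build_tp_plan(model, colwise_style=lambda: "colwise", rowwise_style=lambda: "rowwise")
    for name, style in plan.items():
        print(name, style)

Inside a real parallelization function, the same helper could be called as ``build_tp_plan(model, ColwiseParallel, RowwiseParallel)`` and its result passed to ``parallelize_module``; passing the classes means every entry gets its own style instance.
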

Next, configure the :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy` in Fabric:

.. code-block:: python

    import lightning as L
    from lightning.fabric.strategies import ModelParallelStrategy

    # 1. Pass the parallelization function to the strategy
    strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)

    # 2. Configure devices and set the strategy in Fabric
    fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
    fabric.launch()

The strategy takes the custom parallelization function as input.
No other changes to your training code are necessary at this point.
Later in the code, when you call ``fabric.setup(model)``, Fabric will apply the ``parallelize_feedforward`` function to the model automatically.

.. collapse:: Full training example (requires at least 2 GPUs).

    .. code-block:: python

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
        from torch.distributed.tensor.parallel import parallelize_module

        import lightning as L
        from lightning.pytorch.demos.boring_classes import RandomDataset
        from lightning.fabric.strategies import ModelParallelStrategy


        class FeedForward(nn.Module):
            def __init__(self, dim, hidden_dim):
                super().__init__()
                self.w1 = nn.Linear(dim, hidden_dim, bias=False)
                self.w2 = nn.Linear(hidden_dim, dim, bias=False)
                self.w3 = nn.Linear(dim, hidden_dim, bias=False)

            def forward(self, x):
                return self.w2(F.silu(self.w1(x)) * self.w3(x))


        def parallelize_feedforward(model, device_mesh):
            # Lightning will set up a device mesh for you
            tp_mesh = device_mesh["tensor_parallel"]
            # Use PyTorch's distributed tensor APIs to parallelize the model
            plan = {
                "w1": ColwiseParallel(),
                "w2": RowwiseParallel(),
                "w3": ColwiseParallel(),
            }
            parallelize_module(model, tp_mesh, plan)
            return model


        strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)
        fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
        fabric.launch()

        # Initialize the model
        model = FeedForward(8192, 8192)
        model = fabric.setup(model)

        # Define the optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
        optimizer = fabric.setup_optimizers(optimizer)

        # Define dataset/dataloader
        dataset = RandomDataset(8192, 64)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
        dataloader = fabric.setup_dataloaders(dataloader)

        # Simplified training loop
        for i, batch in enumerate(dataloader):
            output = model(batch)
            loss = output.sum()
            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            fabric.print(f"Iteration {i} complete")

        fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


.. note:: Tensor Parallelism in Lightning Fabric as well as PyTorch is experimental. The APIs may change in the future.

When measuring the peak memory consumption, we should see that doubling the number of GPUs reduces the memory consumption roughly by half:

.. list-table::
   :widths: 20 20 20 20 20
   :header-rows: 1

   * -
     - 1 GPU (no TP)
     - 2 GPUs
     - 4 GPUs
     - 8 GPUs
   * - Memory per GPU
     - 4.04 GB
     - 2.03 GB
     - 1.02 GB
     - 0.60 GB
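
The roughly halving numbers follow from the fact that most of the memory in this toy model is tied to the weights, and every weight matrix is split evenly across the GPUs. A back-of-the-envelope estimate can be computed as follows. It assumes fp32 parameters and gradients and two AdamW state tensors per parameter, and it ignores activations and temporary buffers; the function name ``estimated_gb`` is hypothetical.

.. code-block:: python

    # Back-of-the-envelope estimate for the toy FeedForward(8192, 8192) example.
    # Assumptions (not measured): fp32 parameters and gradients, AdamW keeping two
    # state tensors per parameter, activations and temporary buffers ignored.
    dim = hidden_dim = 8192
    num_params = 3 * dim * hidden_dim  # w1, w2 and w3, no biases
    bytes_per_value = 4  # float32
    copies = 1 + 1 + 2  # parameters + gradients + AdamW exp_avg and exp_avg_sq


    def estimated_gb(num_gpus):
        """Parameter-related memory per GPU if every linear layer is split evenly."""
        return num_params * bytes_per_value * copies / num_gpus / 1e9


    for n in (1, 2, 4, 8):
        print(f"{n} GPU(s): ~{estimated_gb(n):.2f} GB")

Under these assumptions, the estimate comes out at about 3.22, 1.61, 0.81 and 0.40 GB for 1, 2, 4 and 8 GPUs. It accounts for most of each measured value. The remainder is memory this sketch does not model, and that part does not necessarily shrink with the number of GPUs, which would be one explanation for the 8-GPU value staying above one eighth of the single-GPU value.
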

Beyond this toy example, we recommend you study our `LLM Tensor Parallel Example (Llama 3) <https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel>`_.

----

.. _tp-data-loading:

***************************
Data-loading considerations
***************************

In a tensor-parallelized model, it is important that the model receives an identical input on each GPU.
Otherwise, training won't converge.
Therefore, when you shuffle data in your dataset or data loader, or when you apply randomized transformations/augmentations to your data, make sure the seed is the same on all GPUs in the tensor-parallel group.
Given this requirement, your global batch size will be limited by the memory of a single GPU.
To scale the batch size and accelerate training further, you can combine :doc:`tensor parallelism with data parallelism (in particular, FSDP) <tp_fsdp>`.
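
As a minimal sketch of what a shared seed means for shuffling, the example below simulates two tensor-parallel ranks on one CPU. Each call to the hypothetical ``make_loader`` helper builds a ``DataLoader`` whose shuffling is driven by a ``torch.Generator`` seeded with the same value, so both ranks see the batches in the same order. In a Fabric script, calling ``fabric.seed_everything(...)`` with the same value on every process is another way to seed the global random number generators. How this interacts with any sampler that ``fabric.setup_dataloaders`` may inject is not covered by this sketch.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy dataset with 32 samples whose values identify their original position
    dataset = TensorDataset(torch.arange(32, dtype=torch.float32).unsqueeze(1))


    def make_loader(seed):
        """Hypothetical helper: a data loader whose shuffling is driven by an explicit seed."""
        generator = torch.Generator()
        generator.manual_seed(seed)
        return DataLoader(dataset, batch_size=8, shuffle=True, generator=generator)


    # Simulate two tensor-parallel ranks that use the same seed
    rank0_batches = [batch for (batch,) in make_loader(seed=42)]
    rank1_batches = [batch for (batch,) in make_loader(seed=42)]

    # Both ranks receive identical inputs, batch by batch
    print(all(torch.equal(a, b) for a, b in zip(rank0_batches, rank1_batches)))
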

----

**********
Next steps
**********

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
   :header: LLM Tensor Parallel Example
   :description: Full example of how to apply tensor parallelism to a large language model (Llama 3)
   :col_css: col-md-4
   :button_link: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel
   :height: 160
   :tag: advanced

.. displayitem::
   :header: 2D Parallel (FSDP + TP)
   :description: Combine Tensor Parallelism with FSDP (2D Parallel) to train efficiently on 100s of GPUs
   :button_link: tp_fsdp.html
   :col_css: col-md-4
   :height: 160
   :tag: advanced

.. displayitem::
   :header: PyTorch API Reference
   :description: Explore the official PyTorch Tensor Parallel APIs
   :button_link: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
   :col_css: col-md-4
   :height: 160
   :tag: advanced

.. raw:: html

        </div>
    </div>

|