Distributed communication docs for Lite (#16373)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Adrian Wälchli 2023-01-18 23:30:51 +01:00 committed by GitHub
parent 01668bfef2
commit 85786d0c83
4 changed files with 275 additions and 51 deletions


@@ -1,48 +0,0 @@
:orphan:
###########################################
Communication between distributed processes
###########################################
Page is under construction.
----
You can also easily use distributed collectives if required.

.. code-block:: python

    fabric = Fabric()

    # Transfer and concatenate tensors across processes
    fabric.all_gather(...)

    # Transfer an object from one process to all the others
    fabric.broadcast(..., src=...)

    # The total number of processes running across all devices and nodes.
    fabric.world_size

    # The global index of the current process across all devices and nodes.
    fabric.global_rank

    # The index of the current process among the processes running on the local node.
    fabric.local_rank

    # The index of the current node.
    fabric.node_rank

    # Whether this global rank is rank zero.
    if fabric.is_global_zero:
        # do something on rank 0
        ...

    # Wait for all processes to enter this call.
    fabric.barrier()

The code stays agnostic, whether you are running on CPU, on two GPUS or on multiple machines with many GPUs.
If you require custom data or model device placement, you can deactivate :class:`~lightning_fabric.fabric.Fabric`'s automatic placement by doing ``fabric.setup_dataloaders(..., move_to_device=False)`` for the data and ``fabric.setup(..., move_to_device=False)`` for the model.
Furthermore, you can access the current device from ``fabric.device`` or rely on :meth:`~lightning_fabric.fabric.Fabric.to_device` utility to move an object to the current device.


@@ -0,0 +1,238 @@
:orphan:
###########################################
Communication between distributed processes
###########################################
With Fabric, you can easily access information about a process or send data between processes through a standardized API that is agnostic to the distributed strategy.
----
*******************
Rank and world size
*******************
The rank assigned to a process is a zero-based index in the range of *0, ..., world size - 1*, where *world size* is the total number of distributed processes.
If you are training on multiple GPUs, you can think of the rank as the *GPU ID* or *GPU index*, although the concept of rank extends to distributed processing in general.
The rank is unique across all processes, regardless of how they are distributed across machines, and it is therefore also called the **global rank**.
We can also identify processes by their **local rank**, which is only unique among the processes running on the same machine, but not globally across all machines.
Finally, each process is associated with a **node rank** in the range *0, ..., num nodes - 1*, which identifies which machine (node) the process is running on.
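
For intuition, the three ranks relate to each other as sketched below in the common case where every node runs the same number of processes (the helper function and its arguments are made up for this illustration):

.. code-block:: python

    def compute_global_rank(node_rank: int, local_rank: int, processes_per_node: int) -> int:
        # Assumes every node runs the same number of processes
        return node_rank * processes_per_node + local_rank


    # Example: local rank 1 on node 2, with 2 processes per node, is global rank 5
    assert compute_global_rank(node_rank=2, local_rank=1, processes_per_node=2) == 5
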
.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_ranks.jpeg
    :alt: The different types of process ranks: local, global, node.
    :width: 100%

Here is how you launch multiple processes in Fabric:

.. code-block:: python

    from lightning.fabric import Fabric

    # Devices and num_nodes determine how many processes there are
    fabric = Fabric(devices=2, num_nodes=3)
    fabric.launch()

Learn more about :doc:`launching distributed training <../fundamentals/launch>`.
And here is how you access all rank and world size information:

.. code-block:: python

    # The total number of processes running across all devices and nodes
    fabric.world_size  # 2 * 3 = 6

    # The global index of the current process across all devices and nodes
    fabric.global_rank  # -> {0, 1, 2, 3, 4, 5}

    # The index of the current process among the processes running on the local node
    fabric.local_rank  # -> {0, 1}

    # The index of the current node
    fabric.node_rank  # -> {0, 1, 2}

    # Do something only on rank 0
    if fabric.global_rank == 0:
        ...
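
For the frequent rank-zero check, Fabric also provides the ``is_global_zero`` property as a shorthand; a minimal example:

.. code-block:: python

    # Equivalent to ``fabric.global_rank == 0``: True only on the rank-zero process
    if fabric.is_global_zero:
        print("This prints only once, from the global rank 0 process.")
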
.. _race conditions:
Avoid race conditions
=====================
Access to the rank information helps you avoid *race conditions*, which could crash your script or corrupt your data.
Such a condition can occur when multiple processes try to write to the same file at the same time, for example when writing a checkpoint file or downloading a dataset.
Prevent this by guarding your logic with a rank check:

.. code-block:: python

    # Only write files from one process (rank 0) ...
    if fabric.global_rank == 0:
        with open("output.txt", "w") as file:
            file.write(...)

    # ... or save from all processes but don't write to the same file
    with open(f"output-{fabric.global_rank}.txt", "w") as file:
        file.write(...)

    # Multi-node: download a dataset when the filesystem is shared between nodes
    if fabric.global_rank == 0:
        download_dataset()

    # Multi-node: download a dataset when the filesystem is NOT shared between nodes
    if fabric.local_rank == 0:
        download_dataset()

----
*******
Barrier
*******
The barrier forces every process to wait until all processes have reached it.
In other words, it acts as a **synchronization point**.

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_barrier.jpeg
    :alt: The barrier for process synchronization
    :width: 100%

A barrier is needed when processes do different amounts of work and as a result fall out of sync.

.. code-block:: python

    from time import sleep

    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=4)
    fabric.launch()

    # Simulate each process taking a different amount of time
    sleep(2 * fabric.global_rank)
    print(f"Process {fabric.global_rank} is done.")

    # Wait for all processes to reach the barrier
    fabric.barrier()
    print("All processes reached the barrier!")

A more realistic scenario is downloading data.
Here, we need to ensure that processes only start to load the data once the download has completed.
Since the download should happen on rank 0 only, to :ref:`avoid race conditions <race conditions>`, we need a barrier:

.. code-block:: python

    if fabric.global_rank == 0:
        print("Downloading dataset. This can take a while ...")
        download_dataset()

    # All other processes wait here until rank 0 is done with downloading:
    fabric.barrier()

    # After everyone reached the barrier, they can access the downloaded files:
    load_dataset()
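
The same pattern applies to any file that a single rank produces and the other ranks need to read back. A minimal sketch, where the file name and its contents are placeholders made up for this example:

.. code-block:: python

    import torch

    stats_path = "stats.pt"  # hypothetical file written by rank 0 only

    if fabric.global_rank == 0:
        # Compute something once and write it to disk (placeholder values)
        torch.save({"mean": 0.5, "std": 0.25}, stats_path)

    # Everyone waits until rank 0 has finished writing ...
    fabric.barrier()

    # ... and then all processes can safely read the file
    stats = torch.load(stats_path)
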
----
.. _broadcast collective:
*********
Broadcast
*********

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_broadcast.jpeg
    :alt: The broadcast collective operation
    :width: 100%

The broadcast operation sends a tensor of data from one process to all other processes so that all end up with the same data.

.. code-block:: python

    fabric = Fabric(...)

    # Transfer a tensor from one process to all the others
    result = fabric.broadcast(tensor)

    # By default, the source is the process with rank 0 ...
    result = fabric.broadcast(tensor, src=0)

    # ... but the source can be changed to a different rank
    result = fabric.broadcast(tensor, src=3)

A concrete example:

.. code-block:: python

    import torch

    from lightning.fabric import Fabric

    fabric = Fabric(devices=4, accelerator="cpu")
    fabric.launch()

    # Data is different on each process
    learning_rate = torch.rand(1)
    print("Before broadcast:", learning_rate)

    # Transfer the tensor from one process (rank 0 by default) to all the others
    learning_rate = fabric.broadcast(learning_rate)
    print("After broadcast:", learning_rate)
----
******
Gather
******

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_all-gather.jpeg
    :alt: The All-gather collective operation
    :width: 100%

The gather operation transfers the tensors from each process to every other process and stacks the results.
As opposed to the :ref:`broadcast <broadcast collective>`, every process gets the data from every other process, not just from a particular rank.

.. code-block:: python

    fabric = Fabric(...)

    # Gather the tensor from all processes and stack the results
    result = fabric.all_gather(tensor)

    # Tip: turn off gradient tracking if you don't need to back-propagate through the result
    with torch.no_grad():
        result = fabric.all_gather(tensor)

A concrete example:

.. code-block:: python

    import torch

    from lightning.fabric import Fabric

    fabric = Fabric(devices=4, accelerator="cpu")
    fabric.launch()

    # Data is different in each process
    data = torch.tensor(10 * fabric.global_rank)

    # Every process gathers the tensors from all other processes
    # and stacks the result:
    result = fabric.all_gather(data)
    print("Result of all-gather:", result)  # tensor([ 0, 10, 20, 30])
----
******
Reduce
******

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_all-reduce.jpeg
    :alt: The All-reduce collective operation
    :width: 100%

The reduce operation combines the tensors from all processes (for example, by summing or averaging them) so that every process ends up with the same result.

.. code-block:: python

    fabric = Fabric(...)

    # Coming soon
    result = fabric.all_reduce(tensor)
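
Since ``all_reduce`` is marked as coming soon above, here is an illustrative sketch of the same idea built only from the gather primitive documented earlier (a workaround for intuition, not the official API), assuming a Fabric launched as in the previous examples:

.. code-block:: python

    # Each process holds a different value
    local_value = torch.tensor(1.0 * fabric.global_rank)

    # A sum-reduce can be emulated by gathering all values
    # and reducing them locally on every process ...
    total = fabric.all_gather(local_value).sum(dim=0)

    # ... and a mean-reduce by averaging instead
    mean = fabric.all_gather(local_value).mean(dim=0)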


@@ -93,7 +93,7 @@ Fundamentals
:tag: basic
.. displayitem::
:header: Distributed Operation
:header: Launch Distributed Training
:description: Launch a Python script on multiple devices and machines
:button_link: fundamentals/launch.html
:col_css: col-md-4
@@ -193,9 +193,9 @@ Advanced Topics
:tag: advanced
.. displayitem::
:header: Collectives
:header: Distributed Communication
:description: Learn all about communication primitives for distributed operation. Gather, reduce, broadcast, etc.
:button_link: advanced/collectives.html
:button_link: advanced/distributed_communication.html
:col_css: col-md-4
:height: 160
:tag: advanced


@@ -143,3 +143,37 @@ Launch inside a Notebook
It is also possible to use Fabric in a Jupyter notebook (including Google Colab, Kaggle, etc.) and launch multiple processes there.
You can learn more about it :ref:`here <Fabric in Notebooks>`.
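
As a quick illustration of the pattern described in that guide, the code you want to run in parallel is wrapped in a function and handed to ``launch()``. A minimal sketch, where the function name and its body are arbitrary:

.. code-block:: python

    from lightning.fabric import Fabric


    def train(fabric):
        # Runs once per process; ``fabric`` is passed in by ``launch()``
        print("Hello from rank", fabric.global_rank)


    fabric = Fabric(accelerator="cpu", devices=2)
    fabric.launch(train)
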
----
**********
Next steps
**********
.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
    :header: Mixed Precision Training
    :description: Save memory and speed up training using mixed precision
    :col_css: col-md-4
    :button_link: ../fundamentals/precision.html
    :height: 160
    :tag: intermediate

.. displayitem::
    :header: Distributed Communication
    :description: Learn all about communication primitives for distributed operation. Gather, reduce, broadcast, etc.
    :button_link: ../advanced/distributed_communication.html
    :col_css: col-md-4
    :height: 160
    :tag: advanced

.. raw:: html

        </div>
    </div>