Distributed communication docs for Lite (#16373)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

parent 01668bfef2
commit 85786d0c83
@@ -1,48 +0,0 @@
:orphan:

###########################################
Communication between distributed processes
###########################################

Page is under construction.

----

You can also easily use distributed collectives if required.

.. code-block:: python

    fabric = Fabric()

    # Transfer and concatenate tensors across processes
    fabric.all_gather(...)

    # Transfer an object from one process to all the others
    fabric.broadcast(..., src=...)

    # The total number of processes running across all devices and nodes.
    fabric.world_size

    # The global index of the current process across all devices and nodes.
    fabric.global_rank

    # The index of the current process among the processes running on the local node.
    fabric.local_rank

    # The index of the current node.
    fabric.node_rank

    # Whether this global rank is rank zero.
    if fabric.is_global_zero:
        # do something on rank 0
        ...

    # Wait for all processes to enter this call.
    fabric.barrier()

The code stays agnostic whether you are running on a CPU, on two GPUs, or on multiple machines with many GPUs.

If you require custom data or model device placement, you can deactivate :class:`~lightning_fabric.fabric.Fabric`'s automatic placement by doing ``fabric.setup_dataloaders(..., move_to_device=False)`` for the data and ``fabric.setup(..., move_to_device=False)`` for the model.
Furthermore, you can access the current device from ``fabric.device`` or rely on the :meth:`~lightning_fabric.fabric.Fabric.to_device` utility to move an object to the current device.

@@ -0,0 +1,238 @@
:orphan:

###########################################
Communication between distributed processes
###########################################

With Fabric, you can easily access information about a process or send data between processes with a standardized API that is agnostic to the distributed strategy.


----

*******************
Rank and world size
*******************

The rank assigned to a process is a zero-based index in the range of *0, ..., world size - 1*, where *world size* is the total number of distributed processes.
If you are using multi-GPU, think of the rank as the *GPU ID* or *GPU index*, although rank extends to distributed processing in general.

The rank is unique across all processes, regardless of how they are distributed across machines, and it is therefore also called the **global rank**.
We can also identify processes by their **local rank**, which is only unique among processes running on the same machine, but not globally across all machines.
Finally, each process is associated with a **node rank** in the range *0, ..., num nodes - 1*, which identifies the machine (node) the process is running on.

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_ranks.jpeg
   :alt: The different types of process ranks: local, global, and node.
   :width: 100%

Here is how you launch multiple processes in Fabric:

.. code-block:: python

    from lightning.fabric import Fabric

    # Devices and num_nodes determine how many processes there are
    fabric = Fabric(devices=2, num_nodes=3)
    fabric.launch()

Learn more about :doc:`launching distributed training <../fundamentals/launch>`.
And here is how you access all rank and world size information:

.. code-block:: python

    # The total number of processes running across all devices and nodes
    fabric.world_size  # 2 * 3 = 6

    # The global index of the current process across all devices and nodes
    fabric.global_rank  # -> {0, 1, 2, 3, 4, 5}

    # The index of the current process among the processes running on the local node
    fabric.local_rank  # -> {0, 1}

    # The index of the current node
    fabric.node_rank  # -> {0, 1, 2}

    # Do something only on rank 0
    if fabric.global_rank == 0:
        ...

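As a rule of thumb, the three ranks are related to each other. The following is a sketch that assumes the usual contiguous rank assignment (two devices per node, as configured above); custom cluster environments may number processes differently:

.. code-block:: python

    # Global rank is typically derived from the node rank and the local rank
    devices_per_node = 2
    assert fabric.global_rank == fabric.node_rank * devices_per_node + fabric.local_rank
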
.. _race conditions:

Avoid race conditions
=====================

Access to the rank information helps you avoid *race conditions*, which could crash your script or corrupt your data.
Such a condition can occur when multiple processes try to write to the same file at the same time, for example, when writing a checkpoint file or downloading a dataset.
Prevent this by guarding your logic with a rank check:

.. code-block:: python

    # Only write files from one process (rank 0) ...
    if fabric.global_rank == 0:
        with open("output.txt", "w") as file:
            file.write(...)

    # ... or save from all processes but don't write to the same file
    with open(f"output-{fabric.global_rank}.txt", "w") as file:
        file.write(...)

    # Multi-node: download a dataset, the filesystem between nodes is shared
    if fabric.global_rank == 0:
        download_dataset()

    # Multi-node: download a dataset, the filesystem between nodes is NOT shared
    if fabric.local_rank == 0:
        download_dataset()

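Fabric also exposes ``fabric.is_global_zero``, which you can use as a shorthand for the rank-zero check above:

.. code-block:: python

    # Equivalent rank-zero guard using the convenience property
    if fabric.is_global_zero:
        with open("output.txt", "w") as file:
            file.write(...)
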
----


*******
Barrier
*******

The barrier forces every process to wait until all processes have reached it.
In other words, it is a **synchronization**.

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_barrier.jpeg
   :alt: The barrier for process synchronization
   :width: 100%

A barrier is needed when processes do different amounts of work and as a result fall out of sync.

.. code-block:: python

    from time import sleep

    fabric = Fabric(accelerator="cpu", devices=4)
    fabric.launch()

    # Simulate each process taking a different amount of time
    sleep(2 * fabric.global_rank)
    print(f"Process {fabric.global_rank} is done.")

    # Wait for all processes to reach the barrier
    fabric.barrier()
    print("All processes reached the barrier!")

A more realistic scenario is downloading data.
Here, we need to ensure that processes only start to load the data once the download has completed.
Since downloading should be done on rank 0 only to :ref:`avoid race conditions <race conditions>`, we need a barrier:

.. code-block:: python

    if fabric.global_rank == 0:
        print("Downloading dataset. This can take a while ...")
        download_dataset()

    # All other processes wait here until rank 0 is done with downloading:
    fabric.barrier()

    # After everyone reached the barrier, they can access the downloaded files:
    load_dataset()

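If this download-then-wait pattern appears in several places, you can wrap it in a small helper. This is just a sketch; ``run_on_rank_zero_then_wait`` is a hypothetical utility, not part of the Fabric API:

.. code-block:: python

    def run_on_rank_zero_then_wait(fabric, fn):
        # Run the function on rank 0 only, then block every process until it has finished
        if fabric.global_rank == 0:
            fn()
        fabric.barrier()


    run_on_rank_zero_then_wait(fabric, download_dataset)
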
----


.. _broadcast collective:

*********
Broadcast
*********

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_broadcast.jpeg
   :alt: The broadcast collective operation
   :width: 100%

The broadcast operation sends a tensor of data from one process to all other processes so that all end up with the same data.

.. code-block:: python

    fabric = Fabric(...)

    # Transfer a tensor from one process to all the others
    result = fabric.broadcast(tensor)

    # By default, the source is the process with rank 0 ...
    result = fabric.broadcast(tensor, src=0)

    # ... which can be changed to a different rank
    result = fabric.broadcast(tensor, src=3)

A concrete example:

.. code-block:: python

    import torch

    fabric = Fabric(devices=4, accelerator="cpu")
    fabric.launch()

    # Data is different on each process
    learning_rate = torch.rand(1)
    print("Before broadcast:", learning_rate)

    # Transfer the tensor from one process to all the others
    learning_rate = fabric.broadcast(learning_rate)
    print("After broadcast:", learning_rate)

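Broadcast is not limited to raw tensors. As a sketch, assuming your installed version of ``broadcast`` accepts arbitrary picklable objects (as its type hints suggest), you can also distribute small Python objects such as a configuration dictionary:

.. code-block:: python

    # Rank 0 builds the configuration, all other ranks start with None
    config = {"lr": 0.01, "seed": 42} if fabric.global_rank == 0 else None

    # After the broadcast, every process holds the dictionary created on rank 0
    config = fabric.broadcast(config, src=0)
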
----


******
Gather
******

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_all-gather.jpeg
   :alt: The all-gather collective operation
   :width: 100%

The gather operation transfers the tensors from each process to every other process and stacks the results.
As opposed to the :ref:`broadcast <broadcast collective>`, every process gets the data from every other process, not just from a particular rank.

.. code-block:: python

    fabric = Fabric(...)

    # Gather the tensors from all processes and stack them
    result = fabric.all_gather(tensor)

    # Tip: Turn off gradient tracking if you don't need to back-propagate through the gather
    with torch.no_grad():
        result = fabric.all_gather(tensor)

A concrete example:

.. code-block:: python

    import torch

    fabric = Fabric(devices=4, accelerator="cpu")
    fabric.launch()

    # Data is different in each process
    data = torch.tensor(10 * fabric.global_rank)

    # Every process gathers the tensors from all other processes
    # and stacks the result:
    result = fabric.all_gather(data)
    print("Result of all-gather:", result)  # tensor([ 0, 10, 20, 30])

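Beyond a single tensor, ``all_gather`` in recent versions also accepts collections of tensors and a ``sync_grads`` flag for differentiable gathering; treat the exact signature as an assumption and verify it against your installed version. A sketch:

.. code-block:: python

    # Gather a dictionary of tensors: the result has the same keys,
    # with each value stacked across all processes
    metrics = {"loss": torch.rand(1), "accuracy": torch.rand(1)}
    gathered = fabric.all_gather(metrics)

    # Keep the operation differentiable if gradients need to flow through it
    gathered_loss = fabric.all_gather(metrics["loss"], sync_grads=True)
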
----


******
Reduce
******

.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_all-reduce.jpeg
   :alt: The all-reduce collective operation
   :width: 100%

The reduce operation combines the tensors from all processes, for example by summing or averaging them, so that every process ends up with the same result.
Support for this collective in Fabric is coming soon:

.. code-block:: python

    fabric = Fabric(...)

    # Coming soon
    result = fabric.all_reduce(tensor)

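Until the Fabric method lands, the raw PyTorch collective can serve as a stopgap. This is a sketch that assumes a distributed process group has already been initialized (which Fabric's distributed strategies set up during ``launch()``):

.. code-block:: python

    import torch
    import torch.distributed as dist

    # Every process contributes its rank; the sum is reduced in place on all processes
    value = torch.tensor([float(fabric.global_rank)], device=fabric.device)
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    print("Result of all-reduce:", value)  # tensor([6.]) with 4 processes (0 + 1 + 2 + 3)
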
@@ -93,7 +93,7 @@ Fundamentals
   :tag: basic

.. displayitem::
   :header: Distributed Operation
   :header: Launch Distributed Training
   :description: Launch a Python script on multiple devices and machines
   :button_link: fundamentals/launch.html
   :col_css: col-md-4

@@ -193,9 +193,9 @@ Advanced Topics
   :tag: advanced

.. displayitem::
   :header: Collectives
   :header: Distributed Communication
   :description: Learn all about communication primitives for distributed operation. Gather, reduce, broadcast, etc.
   :button_link: advanced/collectives.html
   :button_link: advanced/distributed_communication.html
   :col_css: col-md-4
   :height: 160
   :tag: advanced

@@ -143,3 +143,37 @@ Launch inside a Notebook

It is also possible to use Fabric in a Jupyter notebook (including Google Colab, Kaggle, etc.) and launch multiple processes there.
You can learn more about it :ref:`here <Fabric in Notebooks>`.


----


**********
Next steps
**********

.. raw:: html

   <div class="display-card-container">
      <div class="row">

.. displayitem::
   :header: Mixed Precision Training
   :description: Save memory and speed up training using mixed precision
   :col_css: col-md-4
   :button_link: ../fundamentals/precision.html
   :height: 160
   :tag: intermediate

.. displayitem::
   :header: Distributed Communication
   :description: Learn all about communication primitives for distributed operation. Gather, reduce, broadcast, etc.
   :button_link: ../advanced/distributed_communication.html
   :col_css: col-md-4
   :height: 160
   :tag: advanced

.. raw:: html

      </div>
   </div>