From 85786d0c83e52c99f724e8fec0cd12e284651198 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 18 Jan 2023 23:30:51 +0100
Subject: [PATCH] Distributed communication docs for Lite (#16373)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Carlos Mocholí
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
---
 .../fabric/advanced/collectives.rst           |  48 ----
 .../advanced/distributed_communication.rst    | 238 ++++++++++++++++++
 docs/source-pytorch/fabric/fabric.rst         |   6 +-
 .../fabric/fundamentals/launch.rst            |  34 +++
 4 files changed, 275 insertions(+), 51 deletions(-)
 delete mode 100644 docs/source-pytorch/fabric/advanced/collectives.rst
 create mode 100644 docs/source-pytorch/fabric/advanced/distributed_communication.rst

diff --git a/docs/source-pytorch/fabric/advanced/collectives.rst b/docs/source-pytorch/fabric/advanced/collectives.rst
deleted file mode 100644
index 1905873902..0000000000
--- a/docs/source-pytorch/fabric/advanced/collectives.rst
+++ /dev/null
@@ -1,48 +0,0 @@
-:orphan:
-
-###########################################
-Communication between distributed processes
-###########################################
-
-Page is under construction.
-
-----
-
-
-You can also easily use distributed collectives if required.
-
-.. code-block:: python
-
-    fabric = Fabric()
-
-    # Transfer and concatenate tensors across processes
-    fabric.all_gather(...)
-
-    # Transfer an object from one process to all the others
-    fabric.broadcast(..., src=...)
-
-    # The total number of processes running across all devices and nodes.
-    fabric.world_size
-
-    # The global index of the current process across all devices and nodes.
-    fabric.global_rank
-
-    # The index of the current process among the processes running on the local node.
-    fabric.local_rank
-
-    # The index of the current node.
-    fabric.node_rank
-
-    # Whether this global rank is rank zero.
-    if fabric.is_global_zero:
-        # do something on rank 0
-        ...
-
-    # Wait for all processes to enter this call.
-    fabric.barrier()
-
-
-The code stays agnostic, whether you are running on CPU, on two GPUS or on multiple machines with many GPUs.
-
-If you require custom data or model device placement, you can deactivate :class:`~lightning_fabric.fabric.Fabric`'s automatic placement by doing ``fabric.setup_dataloaders(..., move_to_device=False)`` for the data and ``fabric.setup(..., move_to_device=False)`` for the model.
-Furthermore, you can access the current device from ``fabric.device`` or rely on :meth:`~lightning_fabric.fabric.Fabric.to_device` utility to move an object to the current device.
diff --git a/docs/source-pytorch/fabric/advanced/distributed_communication.rst b/docs/source-pytorch/fabric/advanced/distributed_communication.rst
new file mode 100644
index 0000000000..405abd4405
--- /dev/null
+++ b/docs/source-pytorch/fabric/advanced/distributed_communication.rst
@@ -0,0 +1,238 @@
+:orphan:
+
+###########################################
+Communication between distributed processes
+###########################################
+
+With Fabric, you can easily access information about a process and send data between processes through a standardized API that is agnostic to the distributed strategy.
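+
+For a quick reference, here is a condensed sketch of the process properties and collective calls covered in the sections below.
+The calls behave the same whether you run on CPU, on multiple GPUs, or on multiple machines:
+
+.. code-block:: python
+
+    fabric = Fabric(...)
+
+    # Process information
+    fabric.world_size  # total number of processes
+    fabric.global_rank  # index of this process across all nodes
+    fabric.local_rank  # index of this process on this node
+    fabric.node_rank  # index of this node
+
+    # Collective calls
+    fabric.broadcast(tensor, src=0)  # send a tensor from one process to all others
+    fabric.all_gather(tensor)  # gather and stack tensors from all processes
+    fabric.barrier()  # wait until all processes reach this point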
+
+
+----
+
+
+*******************
+Rank and world size
+*******************
+
+The rank assigned to a process is a zero-based index in the range of *0, ..., world size - 1*, where *world size* is the total number of distributed processes.
+If you are using multi-GPU, think of the rank as the *GPU ID* or *GPU index*, although rank extends to distributed processing in general.
+
+The rank is unique across all processes, regardless of how they are distributed across machines, and it is therefore also called **global rank**.
+We can also identify processes by their **local rank**, which is only unique among processes running on the same machine, but is not unique globally across all machines.
+Finally, each process is associated with a **node rank** in the range *0, ..., num nodes - 1*, which identifies the machine (node) the process is running on.
+
+.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_ranks.jpeg
+    :alt: The different types of process ranks: Local, global, node.
+    :width: 100%
+
+Here is how you launch multiple processes in Fabric:
+
+.. code-block:: python
+
+    from lightning.fabric import Fabric
+
+    # Devices and num_nodes determine how many processes there are
+    fabric = Fabric(devices=2, num_nodes=3)
+    fabric.launch()
+
+Learn more about :doc:`launching distributed training <../fundamentals/launch>`.
+And here is how you access all rank and world size information:
+
+.. code-block:: python
+
+    # The total number of processes running across all devices and nodes
+    fabric.world_size  # 2 * 3 = 6
+
+    # The global index of the current process across all devices and nodes
+    fabric.global_rank  # -> {0, 1, 2, 3, 4, 5}
+
+    # The index of the current process among the processes running on the local node
+    fabric.local_rank  # -> {0, 1}
+
+    # The index of the current node
+    fabric.node_rank  # -> {0, 1, 2}
+
+    # Do something only on rank 0
+    if fabric.global_rank == 0:
+        ...
+
+
+.. _race conditions:
+
+Avoid race conditions
+=====================
+
+Access to the rank information helps you avoid *race conditions*, which could crash your script or corrupt your data.
+Such conditions can occur when multiple processes try to write to the same file at the same time, for example, when writing a checkpoint file or downloading a dataset.
+Prevent this by guarding your logic with a rank check:
+
+.. code-block:: python
+
+    # Only write files from one process (rank 0) ...
+    if fabric.global_rank == 0:
+        with open("output.txt", "w") as file:
+            file.write(...)
+
+    # ... or save from all processes but don't write to the same file
+    with open(f"output-{fabric.global_rank}.txt", "w") as file:
+        file.write(...)
+
+    # Multi-node: download a dataset if the filesystem between nodes is shared
+    if fabric.global_rank == 0:
+        download_dataset()
+
+    # Multi-node: download a dataset if the filesystem between nodes is NOT shared
+    if fabric.local_rank == 0:
+        download_dataset()
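+
+As an illustration of the checkpoint case mentioned above, here is a minimal sketch that writes a checkpoint from rank 0 only, using plain ``torch.save``.
+The ``model`` object and the filename are placeholders for your own script; the ``fabric.barrier()`` call is explained in the next section:
+
+.. code-block:: python
+
+    import torch
+
+    # Only rank 0 writes the checkpoint file, so no two processes write concurrently
+    if fabric.global_rank == 0:
+        torch.save(model.state_dict(), "checkpoint.pt")
+
+    # Make all other processes wait until the file exists before they read it back
+    fabric.barrier()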
+----
+
+
+*******
+Barrier
+*******
+
+The barrier forces every process to wait until all processes have reached it.
+In other words, it is a **synchronization**.
+
+.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_barrier.jpeg
+    :alt: The barrier for process synchronization
+    :width: 100%
+
+A barrier is needed when processes do different amounts of work and, as a result, fall out of sync.
+
+.. code-block:: python
+
+    from time import sleep
+
+    from lightning.fabric import Fabric
+
+    fabric = Fabric(accelerator="cpu", devices=4)
+    fabric.launch()
+
+    # Simulate each process taking a different amount of time
+    sleep(2 * fabric.global_rank)
+    print(f"Process {fabric.global_rank} is done.")
+
+    # Wait for all processes to reach the barrier
+    fabric.barrier()
+    print("All processes reached the barrier!")
+
+
+A more realistic scenario is downloading data.
+Here, we need to ensure that processes only start to load the data once the download has completed.
+Since the download should be done on rank 0 only to :ref:`avoid race conditions <race conditions>`, we need a barrier:
+
+.. code-block:: python
+
+    if fabric.global_rank == 0:
+        print("Downloading dataset. This can take a while ...")
+        download_dataset()
+
+    # All other processes wait here until rank 0 is done with downloading:
+    fabric.barrier()
+
+    # After everyone reached the barrier, they can access the downloaded files:
+    load_dataset()
+
+
+----
+
+.. _broadcast collective:
+
+*********
+Broadcast
+*********
+
+.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_broadcast.jpeg
+    :alt: The broadcast collective operation
+    :width: 100%
+
+The broadcast operation sends a tensor of data from one process to all other processes so that all end up with the same data.
+
+.. code-block:: python
+
+    fabric = Fabric(...)
+
+    # Transfer a tensor from one process to all the others
+    result = fabric.broadcast(tensor)
+
+    # By default, the source is the process with rank 0 ...
+    result = fabric.broadcast(tensor, src=0)
+
+    # ... which can be changed to a different rank
+    result = fabric.broadcast(tensor, src=3)
+
+
+A concrete example:
+
+.. code-block:: python
+
+    import torch
+
+    from lightning.fabric import Fabric
+
+    fabric = Fabric(devices=4, accelerator="cpu")
+    fabric.launch()
+
+    # Data is different on each process
+    learning_rate = torch.rand(1)
+    print("Before broadcast:", learning_rate)
+
+    # Transfer the tensor from one process (rank 0 by default) to all the others
+    learning_rate = fabric.broadcast(learning_rate)
+    print("After broadcast:", learning_rate)
+
+
+----
+
+
+******
+Gather
+******
+
+.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_all-gather.jpeg
+    :alt: The All-gather collective operation
+    :width: 100%
+
+The gather operation transfers the tensors from each process to every other process and stacks the results.
+As opposed to the :ref:`broadcast <broadcast collective>`, every process gets the data from every other process, not just from a particular rank.
+
+.. code-block:: python
+
+    fabric = Fabric(...)
+
+    # Gather the data from all processes and stack it
+    result = fabric.all_gather(tensor)
+
+    # Tip: Turn off gradient syncing if you don't need to back-propagate through it
+    with torch.no_grad():
+        result = fabric.all_gather(tensor)
+
+
+A concrete example:
+
+.. code-block:: python
+
+    import torch
+
+    from lightning.fabric import Fabric
+
+    fabric = Fabric(devices=4, accelerator="cpu")
+    fabric.launch()
+
+    # Data is different in each process
+    data = torch.tensor(10 * fabric.global_rank)
+
+    # Every process gathers the tensors from all other processes
+    # and stacks the result:
+    result = fabric.all_gather(data)
+    print("Result of all-gather:", result)  # tensor([ 0, 10, 20, 30])
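+
+
+A typical use of ``all_gather`` is computing a metric over the data from every process, for example, averaging a validation loss.
+Below is a minimal sketch; the per-process ``loss`` value is a hypothetical stand-in for whatever scalar tensor your script computes:
+
+.. code-block:: python
+
+    import torch
+
+    # Hypothetical per-process scalar, e.g., a loss computed on this rank's data shard
+    loss = torch.rand(1)
+
+    # Stack the values from all processes: the shape becomes (world_size, 1)
+    gathered = fabric.all_gather(loss)
+
+    # Every process now holds the same global average
+    avg_loss = gathered.mean()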
+
+
+----
+
+
+******
+Reduce
+******
+
+.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric_collectives_all-reduce.jpeg
+    :alt: The All-reduce collective operation
+    :width: 100%
+
+The reduce operation combines the tensors from all processes with a reduction operation, such as a sum or mean, so that every process ends up with the same result.
+
+.. code-block:: python
+
+    fabric = Fabric(...)
+
+    # Coming soon
+    result = fabric.all_reduce(tensor)
diff --git a/docs/source-pytorch/fabric/fabric.rst b/docs/source-pytorch/fabric/fabric.rst
index d7d7184ba2..464d541077 100644
--- a/docs/source-pytorch/fabric/fabric.rst
+++ b/docs/source-pytorch/fabric/fabric.rst
@@ -93,7 +93,7 @@ Fundamentals
     :tag: basic

 .. displayitem::
-    :header: Distributed Operation
+    :header: Launch Distributed Training
     :description: Launch a Python script on multiple devices and machines
     :button_link: fundamentals/launch.html
     :col_css: col-md-4
@@ -193,9 +193,9 @@ Advanced Topics
     :tag: advanced

 .. displayitem::
-    :header: Collectives
+    :header: Distributed Communication
     :description: Learn all about communication primitives for distributed operation. Gather, reduce, broadcast, etc.
-    :button_link: advanced/collectives.html
+    :button_link: advanced/distributed_communication.html
     :col_css: col-md-4
     :height: 160
     :tag: advanced
diff --git a/docs/source-pytorch/fabric/fundamentals/launch.rst b/docs/source-pytorch/fabric/fundamentals/launch.rst
index 043f2c7f65..efa24c6be3 100644
--- a/docs/source-pytorch/fabric/fundamentals/launch.rst
+++ b/docs/source-pytorch/fabric/fundamentals/launch.rst
@@ -143,3 +143,37 @@ Launch inside a Notebook

 It is also possible to use Fabric in a Jupyter notebook (including Google Colab, Kaggle, etc.) and launch multiple processes there.
 You can learn more about it :ref:`here `.
+
+
+----
+
+
+**********
+Next steps
+**********
+
+.. raw:: html
+
+    <div class="display-card-container">
+        <div class="row">
+
+.. displayitem::
+    :header: Mixed Precision Training
+    :description: Save memory and speed up training using mixed precision
+    :col_css: col-md-4
+    :button_link: ../fundamentals/precision.html
+    :height: 160
+    :tag: intermediate
+
+.. displayitem::
+    :header: Distributed Communication
+    :description: Learn all about communication primitives for distributed operation. Gather, reduce, broadcast, etc.
+    :button_link: ../advanced/distributed_communication.html
+    :col_css: col-md-4
+    :height: 160
+    :tag: advanced
+
+.. raw:: html
+
+        </div>
+    </div>