#############
Fabric (Beta)
#############

The :class:`~lightning_fabric.fabric.Fabric` library allows you to scale any PyTorch model with just a few lines of code!
With Fabric you can easily scale your model to run on distributed devices using the strategy of your choice, while keeping full control over the training loop and optimization logic.

With only a few changes to your code, Fabric gives you:

- Automatic placement of models and data onto the device
- Automatic support for mixed precision (speedup and smaller memory footprint)
- Seamless switching between hardware (CPU, GPU, TPU)
- State-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed)
- Easy-to-use launch command for spawning processes (DDP, torchelastic, etc.)
- Multi-node support (TorchElastic, SLURM, and more)
- Full control of your training loop

.. code-block:: diff

      import torch
      import torch.nn as nn
      from torch.utils.data import DataLoader, Dataset

    + from lightning.fabric import Fabric

      class MyModel(nn.Module):
          ...

      class MyDataset(Dataset):
          ...

    + fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
    + fabric.launch()

    - device = "cuda" if torch.cuda.is_available() else "cpu"
      model = MyModel(...)
      optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    + model, optimizer = fabric.setup(model, optimizer)
      dataloader = DataLoader(MyDataset(...), ...)
    + dataloader = fabric.setup_dataloaders(dataloader)
      model.train()

      for epoch in range(num_epochs):
          for batch in dataloader:
    -         batch.to(device)
              optimizer.zero_grad()
              loss = model(batch)
    -         loss.backward()
    +         fabric.backward(loss)
              optimizer.step()

.. note:: :class:`~lightning_fabric.fabric.Fabric` is currently a beta feature. Its API is subject to change based on feedback.

----------

*****************
Convert to Fabric
*****************

Here are five easy steps to let :class:`~lightning_fabric.fabric.Fabric` scale your PyTorch models.

**Step 1:** Create the :class:`~lightning_fabric.fabric.Fabric` object at the beginning of your training code.

.. code-block:: python

    from lightning.fabric import Fabric

    fabric = Fabric()

**Step 2:** Call :meth:`~lightning_fabric.fabric.Fabric.setup` on each model and optimizer pair and :meth:`~lightning_fabric.fabric.Fabric.setup_dataloaders` on all your dataloaders.

.. code-block:: python

    model, optimizer = fabric.setup(model, optimizer)
    dataloader = fabric.setup_dataloaders(dataloader)

**Step 3:** Remove all ``.to`` and ``.cuda`` calls since :class:`~lightning_fabric.fabric.Fabric` will take care of it.

.. code-block:: diff

    - model.to(device)
    - batch.to(device)

**Step 4:** Replace ``loss.backward()`` with ``fabric.backward(loss)``.

.. code-block:: diff

    - loss.backward()
    + fabric.backward(loss)

**Step 5:** Run the script from the terminal with

.. code-block:: bash

    lightning run model path/to/train.py

or use the :meth:`~lightning_fabric.fabric.Fabric.launch` method in a notebook.

|

That's it! You can now train on any device at any scale by switching a flag.
Check out our examples that use Fabric:

- `Image Classification <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/image_classifier/README.md>`_
- `Generative Adversarial Network (GAN) <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/dcgan/README.md>`_

Here is how you run DDP with 8 GPUs and `torch.bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_ precision:

.. code-block:: bash

    lightning run model ./path/to/train.py --strategy=ddp --devices=8 --accelerator=cuda --precision="bf16"

Or `DeepSpeed Zero3 <https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html>`_ with mixed precision:

.. code-block:: bash

    lightning run model ./path/to/train.py --strategy=deepspeed --devices=8 --accelerator=cuda --precision=16

:class:`~lightning_fabric.fabric.Fabric` can also figure it out automatically for you!

.. code-block:: bash

    lightning run model ./path/to/train.py --devices=auto --accelerator=auto --precision=16

You can also easily use distributed collectives if required.

.. code-block:: python

    fabric = Fabric()

    # Transfer and concatenate tensors across processes
    fabric.all_gather(...)

    # Transfer an object from one process to all the others
    fabric.broadcast(..., src=...)

    # The total number of processes running across all devices and nodes.
    fabric.world_size

    # The global index of the current process across all devices and nodes.
    fabric.global_rank

    # The index of the current process among the processes running on the local node.
    fabric.local_rank

    # The index of the current node.
    fabric.node_rank

    # Whether this global rank is rank zero.
    if fabric.is_global_zero:
        # do something on rank 0
        ...

    # Wait for all processes to enter this call.
    fabric.barrier()

The same code works unchanged whether you are running on a CPU, on two GPUs, or on multiple machines with many GPUs.
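
For intuition, the rank attributes listed above are related by simple arithmetic when every node runs the same number of processes and global ranks are assigned node by node. The plain-Python sketch below illustrates that assumed convention; ``describe_ranks`` is a hypothetical helper and not part of Fabric, which computes these values for you.

.. code-block:: python

    def describe_ranks(num_nodes, devices_per_node):
        """Enumerate (node_rank, local_rank, global_rank) for a homogeneous cluster.

        Hypothetical helper: assumes every node runs the same number of processes
        and that global ranks are numbered node by node.
        """
        world_size = num_nodes * devices_per_node
        layout = []
        for node_rank in range(num_nodes):
            for local_rank in range(devices_per_node):
                global_rank = node_rank * devices_per_node + local_rank
                layout.append((node_rank, local_rank, global_rank))
        return world_size, layout


    world_size, layout = describe_ranks(num_nodes=2, devices_per_node=4)
    print(world_size)  # 8
    for node_rank, local_rank, global_rank in layout:
        # Exactly one process in the whole job has global rank zero.
        is_global_zero = global_rank == 0
        print(node_rank, local_rank, global_rank, is_global_zero)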

If you require custom data or model device placement, you can deactivate :class:`~lightning_fabric.fabric.Fabric`'s automatic placement by doing ``fabric.setup_dataloaders(..., move_to_device=False)`` for the data and ``fabric.setup(..., move_to_device=False)`` for the model.
Furthermore, you can access the current device from ``fabric.device`` or rely on the :meth:`~lightning_fabric.fabric.Fabric.to_device` utility to move an object to the current device.

----------

*******************
Fabric in Notebooks
*******************

Fabric works exactly the same way in notebooks (Jupyter, Google Colab, Kaggle, etc.) if you only run in a single process or on a single GPU.
If you want to use multiprocessing, for example multi-GPU, you can put your code in a function and pass that function to the
:meth:`~lightning_fabric.fabric.Fabric.launch` method:

.. code-block:: python

    # Notebook Cell
    def train(fabric):
        model = ...
        optimizer = ...
        model, optimizer = fabric.setup(model, optimizer)
        ...


    # Notebook Cell
    fabric = Fabric(accelerator="cuda", devices=2)
    fabric.launch(train)  # Launches the `train` function on two GPUs

As you can see, this function accepts one argument, the ``Fabric`` object, and it gets launched on as many devices as specified.

----------

************
Fabric Flags
************

Fabric is designed to accelerate distributed training and inference. It makes it easy to configure your device and communication strategy, and to switch seamlessly from one to the other.

accelerator
===========

Choose one of ``"cpu"``, ``"gpu"``, ``"tpu"``, or ``"auto"`` (IPU support is coming soon).

.. code-block:: python

    # CPU accelerator
    fabric = Fabric(accelerator="cpu")

    # Running with GPU Accelerator using 2 GPUs
    fabric = Fabric(devices=2, accelerator="gpu")

    # Running with TPU Accelerator using 8 TPU cores
    fabric = Fabric(devices=8, accelerator="tpu")

    # Running with GPU Accelerator using the DistributedDataParallel strategy
    fabric = Fabric(devices=4, accelerator="gpu", strategy="ddp")

The ``"auto"`` option recognizes the machine you are on and selects the available accelerator.

.. code-block:: python

    # If your machine has GPUs, it will use the GPU Accelerator
    fabric = Fabric(devices=2, accelerator="auto")

strategy
========

Choose a training strategy: ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"tpu_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``, or ``"ddp_sharded_spawn"``.

.. code-block:: python

    # Running with the DistributedDataParallel strategy on 4 GPUs
    fabric = Fabric(strategy="ddp", accelerator="gpu", devices=4)

    # Running with the DDP Spawn strategy using 4 CPU processes
    fabric = Fabric(strategy="ddp_spawn", accelerator="cpu", devices=4)

Additionally, you can pass in a strategy object configured with additional parameters.

.. code-block:: python

    from lightning.fabric.strategies import DeepSpeedStrategy

    fabric = Fabric(strategy=DeepSpeedStrategy(stage=2), accelerator="gpu", devices=2)

Support for Fully Sharded training strategies is coming soon.

devices
=======

Configure the devices to run on. Can be of type:

- int: the number of devices (e.g., GPUs) to train on
- list of int: which device index (e.g., GPU ID) to train on (0-indexed)
- str: a string representation of one of the above

.. code-block:: python

    # default used by Fabric, i.e., use the CPU
    fabric = Fabric(devices=None)

    # equivalent
    fabric = Fabric(devices=0)

    # int: run on two GPUs
    fabric = Fabric(devices=2, accelerator="gpu")

    # list: run on GPUs 1, 4 (by bus ordering)
    fabric = Fabric(devices=[1, 4], accelerator="gpu")
    fabric = Fabric(devices="1, 4", accelerator="gpu")  # equivalent

    # -1: run on all GPUs
    fabric = Fabric(devices=-1, accelerator="gpu")
    fabric = Fabric(devices="-1", accelerator="gpu")  # equivalent
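
To illustrate how the accepted forms could map onto concrete device indices, here is a minimal plain-Python sketch. ``parse_devices_sketch`` is a hypothetical function, not Fabric's actual parser; it assumes ``num_available`` devices are visible and covers only the cases listed above.

.. code-block:: python

    def parse_devices_sketch(devices, num_available):
        """Map a ``devices`` value to a list of device indices (illustrative only).

        Hypothetical helper, not Fabric's parser. Handles the forms listed above:
        an int count, -1 for "all", a list of indices, or a string of either.
        """
        if devices is None or devices == 0:
            return []  # no accelerator devices requested: run on the CPU
        if isinstance(devices, str):
            text = devices.strip()
            if "," in text:
                # "1, 4" -> [1, 4]: a comma-separated string lists device indices
                devices = [int(part) for part in text.split(",") if part.strip()]
            else:
                # "2" -> 2, "-1" -> -1: a plain string is parsed as an int
                devices = int(text)
        if isinstance(devices, int):
            if devices == -1:
                return list(range(num_available))  # all visible devices
            if devices > num_available:
                raise ValueError(f"Requested {devices} devices, only {num_available} available")
            return list(range(devices))  # the first `devices` indices
        return list(devices)  # explicit list of indices


    print(parse_devices_sketch(2, num_available=8))       # [0, 1]
    print(parse_devices_sketch("1, 4", num_available=8))  # [1, 4]
    print(parse_devices_sketch("-1", num_available=4))    # [0, 1, 2, 3]

Note the design choice in this sketch: an integer is a *count* while a list is a set of *indices*, which is why ``2`` and ``[1, 4]`` select different GPUs even though both use two devices.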

num_nodes
=========

Number of cluster nodes for distributed operation.

.. code-block:: python

    # Default used by Fabric
    fabric = Fabric(num_nodes=1)

    # Run on 8 nodes
    fabric = Fabric(num_nodes=8)

Learn more about distributed multi-node training on clusters :doc:`here <../clouds/cluster>`.

precision
=========

Fabric supports double precision (64), full precision (32), or half precision (16) operation (including `bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_).
Half precision, or mixed precision, is the combined use of 32-bit and 16-bit floating-point numbers to reduce the memory footprint during model training.
This can result in improved performance, achieving significant speedups on modern GPUs.

.. code-block:: python

    # Default used by Fabric
    fabric = Fabric(precision=32, devices=1)

    # 16-bit (mixed) precision
    fabric = Fabric(precision=16, devices=1)

    # 16-bit bfloat precision
    fabric = Fabric(precision="bf16", devices=1)

    # 64-bit (double) precision
    fabric = Fabric(precision=64, devices=1)
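
As a back-of-the-envelope illustration of the memory side of this trade-off, the snippet below computes how much storage a tensor of a given size needs in each format. The 10-million-value size is an arbitrary assumption, and the calculation ignores the fact that mixed precision in PyTorch typically keeps the weights themselves in 32-bit and saves memory mainly on activations.

.. code-block:: python

    import struct

    # Bytes needed to store one floating-point value in each format.
    # struct's "d", "f" and "e" codes are IEEE double, single and half precision.
    # bfloat16 has no struct code; it is 16 bits wide, so it also takes 2 bytes.
    BYTES_PER_VALUE = {
        "64 (double)": struct.calcsize("d"),
        "32 (full)": struct.calcsize("f"),
        "16 (half)": struct.calcsize("e"),
        "bf16": 2,
    }


    def tensor_megabytes(num_values, precision):
        """Storage for `num_values` values at the given precision, in MB (10**6 bytes)."""
        return num_values * BYTES_PER_VALUE[precision] / 1e6


    num_values = 10_000_000  # assumed size, e.g. one large activation tensor
    for precision in BYTES_PER_VALUE:
        print(f"{precision:>12}: {tensor_megabytes(num_values, precision):.0f} MB")

For 10 million values this gives 80, 40, 20, and 20 MB respectively, so each halving of the bit width halves the storage.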

plugins
=======

:ref:`Plugins` allow you to connect arbitrary backends, precision libraries, clusters, etc.
To define your own behavior, subclass the relevant class and pass it in. For example, here's how to link up your own
:class:`~lightning.fabric.plugins.environments.ClusterEnvironment`.

.. code-block:: python

    from lightning.fabric.plugins.environments import ClusterEnvironment


    class MyCluster(ClusterEnvironment):
        @property
        def main_address(self):
            return your_main_address

        @property
        def main_port(self):
            return your_main_port

        def world_size(self):
            return the_world_size

        # Implement the remaining abstract methods as well (e.g. global_rank, local_rank, node_rank)


    fabric = Fabric(plugins=[MyCluster()])  # add any other flags you need

callbacks
=========

A callback class is a collection of methods that the training loop can call at a specific point in time, for example, at the end of an epoch.
Add callbacks to Fabric to inject logic into your training loop from an external callback class.

.. code-block:: python

    class MyCallback:
        def on_train_epoch_end(self, results):
            ...

You can then register this callback, or multiple ones, directly in Fabric:

.. code-block:: python

    fabric = Fabric(callbacks=[MyCallback()])

Then, in your training loop, you can call a hook by its name. Any callback objects that have this hook will execute it:

.. code-block:: python

    # Call any hook by name
    fabric.call("on_train_epoch_end", results={...})

----------

**************
Fabric Methods
**************

setup
=====

Set up a model and corresponding optimizer(s). If you need to set up multiple models, call ``setup()`` on each of them.
It moves the model and optimizer to the correct device automatically.

.. code-block:: python

    model = nn.Linear(32, 64)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Set up model and optimizer for accelerated training
    model, optimizer = fabric.setup(model, optimizer)

    # If you don't want Fabric to set the device
    model, optimizer = fabric.setup(model, optimizer, move_to_device=False)

The setup method also prepares the model for the selected precision choice so that operations during ``forward()`` get
cast automatically.

setup_dataloaders
=================

Set up one or multiple dataloaders for accelerated operation. If you are running a distributed strategy (e.g., DDP), Fabric
replaces the sampler automatically for you. In addition, the dataloader will be configured to move the returned
data tensors to the correct device automatically.

.. code-block:: python

    train_data = torch.utils.data.DataLoader(train_dataset, ...)
    test_data = torch.utils.data.DataLoader(test_dataset, ...)

    train_data, test_data = fabric.setup_dataloaders(train_data, test_data)

    # If you don't want Fabric to move the data to the device
    train_data, test_data = fabric.setup_dataloaders(train_data, test_data, move_to_device=False)

    # If you don't want Fabric to replace the sampler in the context of distributed training
    train_data, test_data = fabric.setup_dataloaders(train_data, test_data, replace_sampler=False)
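
To see what replacing the sampler achieves, the plain-Python sketch below splits dataset indices across processes the way PyTorch's ``DistributedSampler`` does when shuffling is off: the index list is padded by wrapping around so it divides evenly, and each rank takes every ``world_size``-th index starting at its own rank. ``shard_indices`` is a hypothetical name, and shuffling and per-epoch seeding, which the real sampler also supports, are left out.

.. code-block:: python

    import math


    def shard_indices(dataset_len, world_size, rank):
        """Indices that process `rank` would see (no shuffling, no dropping).

        Hypothetical helper mimicking the non-shuffled behaviour of
        torch.utils.data.DistributedSampler with drop_last=False.
        """
        num_samples = math.ceil(dataset_len / world_size)
        total_size = num_samples * world_size
        indices = list(range(dataset_len))
        # Pad by wrapping around so every rank gets the same number of samples.
        padding = total_size - len(indices)
        indices += (indices * math.ceil(padding / max(len(indices), 1)))[:padding]
        # Each rank takes a strided slice: rank, rank + world_size, ...
        return indices[rank:total_size:world_size]


    for rank in range(3):
        print(rank, shard_indices(dataset_len=10, world_size=3, rank=rank))

With 10 samples and 3 processes, each rank receives 4 indices (``[0, 3, 6, 9]``, ``[1, 4, 7, 0]`` and ``[2, 5, 8, 1]``); samples 0 and 1 are seen twice so that all ranks run the same number of steps.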

backward
========

This replaces any occurrences of ``loss.backward()`` and makes your code accelerator and precision agnostic.

.. code-block:: python

    output = model(input)
    loss = loss_fn(output, target)

    # loss.backward()
    fabric.backward(loss)

to_device
=========

Use :meth:`~lightning_fabric.fabric.Fabric.to_device` to move models, tensors or collections of tensors to
the current device. By default :meth:`~lightning_fabric.fabric.Fabric.setup` and
:meth:`~lightning_fabric.fabric.Fabric.setup_dataloaders` already move the model and data to the correct
device, so calling this method is only necessary for manual operation when needed.

.. code-block:: python

    data = torch.load("dataset.pt")
    data = fabric.to_device(data)
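
Conceptually, moving a collection means walking the nested structure and transferring every tensor leaf while keeping the containers intact. The plain-Python sketch below shows that traversal pattern. ``apply_to_leaves`` is a hypothetical helper rather than Fabric's implementation, and a tagging function stands in for the real ``tensor.to(device)`` call so that the example runs without PyTorch.

.. code-block:: python

    def apply_to_leaves(data, fn, leaf_type):
        """Recursively apply `fn` to every `leaf_type` element of a nested collection.

        Hypothetical helper: dicts, lists and tuples are rebuilt with the same
        structure; anything that is neither a container nor a leaf is returned as-is.
        (Named tuples are not handled by this simplified version.)
        """
        if isinstance(data, leaf_type):
            return fn(data)
        if isinstance(data, dict):
            return {key: apply_to_leaves(value, fn, leaf_type) for key, value in data.items()}
        if isinstance(data, (list, tuple)):
            return type(data)(apply_to_leaves(item, fn, leaf_type) for item in data)
        return data


    # Stand-in for `tensor.to(device)`: here the "tensors" are floats and "moving"
    # them just pairs each value with the target device name.
    def fake_to_device(value):
        return ("cuda:0", value)


    batch = {"inputs": [1.0, 2.0], "target": (3.0,), "name": "sample"}
    print(apply_to_leaves(batch, fake_to_device, leaf_type=float))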

seed_everything
===============

Make your code reproducible by calling this method at the beginning of your run.

.. code-block:: python

    # Instead of `torch.manual_seed(...)`, call:
    fabric.seed_everything(1234)

This covers PyTorch, NumPy and Python random number generators. In addition, Fabric takes care of properly initializing
the seed of dataloader worker processes (can be turned off by passing ``workers=False``).
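
The idea behind seeding is that a generator restarted from the same seed replays the same sequence, while distinct workers need distinct seeds so they don't all produce identical random augmentations. The stdlib-only sketch below demonstrates both points with Python's ``random`` module. ``seed_all`` and ``worker_seed`` are hypothetical helpers; the per-worker formula is an illustrative assumption, not the scheme Fabric uses, and a real setup would also seed NumPy and PyTorch.

.. code-block:: python

    import random


    def seed_all(seed):
        """Seed Python's global RNG (a real setup would also seed NumPy and PyTorch)."""
        random.seed(seed)


    def worker_seed(base_seed, rank, worker_id, num_workers):
        """Hypothetical per-worker seed: unique for each (rank, worker) pair."""
        return base_seed + rank * num_workers + worker_id


    seed_all(1234)
    first_run = [random.random() for _ in range(3)]

    seed_all(1234)
    second_run = [random.random() for _ in range(3)]

    print(first_run == second_run)  # True: same seed, same sequence
    print([worker_seed(1234, rank=1, worker_id=w, num_workers=4) for w in range(4)])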

autocast
========

Let the precision backend autocast the block of code under this context manager. This is optional and already done by
Fabric for the model's forward method (once the model has been set up with :meth:`~lightning_fabric.fabric.Fabric.setup`).
You need this only if you wish to autocast more operations outside the ones in model forward:

.. code-block:: python

    model, optimizer = fabric.setup(model, optimizer)

    # Fabric handles precision automatically for the model
    output = model(inputs)

    with fabric.autocast():  # optional
        loss = loss_function(output, target)

    fabric.backward(loss)
    ...

print
=====

Print to the console via the built-in print function, but only on the main process.
This avoids excessive printing and logs when running on multiple devices/nodes.

.. code-block:: python

    # Print only on the main process
    fabric.print(f"{epoch}/{num_epochs}| Train Epoch Loss: {loss}")
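
The behaviour amounts to a rank check in front of the built-in ``print``. The sketch below shows that pattern in plain Python; ``RankAwarePrinter`` is a hypothetical stand-in for the part of Fabric that knows the current process's global rank.

.. code-block:: python

    class RankAwarePrinter:
        """Hypothetical stand-in that only prints on global rank 0."""

        def __init__(self, global_rank):
            self.global_rank = global_rank

        @property
        def is_global_zero(self):
            return self.global_rank == 0

        def print(self, *args, **kwargs):
            # Forward to the built-in print only on the main process.
            # Returns whether anything was printed (handy for testing).
            if self.is_global_zero:
                print(*args, **kwargs)
                return True
            return False


    for rank in range(4):
        RankAwarePrinter(rank).print(f"hello from rank {rank}")  # only rank 0 prints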

save
====

Save contents to a checkpoint. Replaces all occurrences of ``torch.save(...)`` in your code. Fabric will take care of
handling the saving part correctly, no matter whether you are running on a single device, multiple devices, or multiple nodes.

.. code-block:: python

    # Instead of `torch.save(...)`, call:
    fabric.save(model.state_dict(), "path/to/checkpoint.ckpt")

load
====

Load checkpoint contents from a file. Replaces all occurrences of ``torch.load(...)`` in your code. Fabric will take care of
handling the loading part correctly, no matter whether you are running on a single device, multiple devices, or multiple nodes.

.. code-block:: python

    # Instead of `torch.load(...)`, call:
    checkpoint = fabric.load("path/to/checkpoint.ckpt")

barrier
=======

Call this if you want all processes to wait and synchronize. Once all processes have entered this call,
execution continues. This is useful, for example, when you want to download data on one process and make all others wait until
the data is written to disk.

.. code-block:: python

    # Download data only on one process
    if fabric.global_rank == 0:
        download_data("http://...")

    # Wait until all processes meet up here
    fabric.barrier()

    # All processes are allowed to read the data now
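
The synchronization pattern can be simulated on a single machine with Python threads standing in for processes. In this sketch, ``threading.Barrier`` plays the role of ``fabric.barrier()`` and writing to a shared dictionary stands in for downloading data; both substitutions are assumptions made so that the example runs without a distributed setup.

.. code-block:: python

    import threading
    import time

    WORLD_SIZE = 4
    barrier = threading.Barrier(WORLD_SIZE)  # stand-in for fabric.barrier()
    shared_disk = {}  # stand-in for the file system
    results = [None] * WORLD_SIZE


    def worker(global_rank):
        if global_rank == 0:
            time.sleep(0.1)  # pretend the download takes a while
            shared_disk["dataset"] = list(range(5))
        # No thread continues until all of them have reached this point.
        barrier.wait()
        # Every rank can now safely read the data written by rank 0.
        results[global_rank] = len(shared_disk["dataset"])


    threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(WORLD_SIZE)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    print(results)  # [5, 5, 5, 5]

Without the ``barrier.wait()`` call, ranks 1 to 3 would try to read ``shared_disk["dataset"]`` before rank 0 finished writing it and fail with a ``KeyError``.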

no_backward_sync
================

Use this context manager when performing gradient accumulation and using a distributed strategy (e.g., DDP).
It will speed up your training loop by cutting redundant communication between processes during the accumulation phase.

.. code-block:: python

    # Accumulate gradient 8 batches at a time (assuming batch_idx starts at 0)
    is_accumulating = (batch_idx + 1) % 8 != 0

    with fabric.no_backward_sync(model, enabled=is_accumulating):
        output = model(input)
        loss = ...
        fabric.backward(loss)
        ...

    # Step the optimizer every 8 batches
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()

Both the model's ``.forward()`` and the ``fabric.backward()`` call need to run under this context as shown in the example above.
For single-device strategies, it is a no-op. There are strategies that don't support this:

- deepspeed
- dp
- xla

For these, the context manager falls back to a no-op and emits a warning.
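
To check which batches end an accumulation window, the plain-Python sketch below evaluates the accumulation condition for a whole epoch, assuming 0-based batch indices and a step once every ``accumulate_every`` batches. ``optimizer_step_batches`` is a hypothetical helper for illustration only.

.. code-block:: python

    def optimizer_step_batches(num_batches, accumulate_every):
        """Return the 0-based batch indices after which the optimizer would step.

        Assumes a step happens once `accumulate_every` batches have been accumulated,
        i.e. whenever (batch_idx + 1) is a multiple of `accumulate_every`.
        """
        steps = []
        for batch_idx in range(num_batches):
            is_accumulating = (batch_idx + 1) % accumulate_every != 0
            # Batches with is_accumulating=True are the ones where gradient sync
            # could be skipped; the others end a window and trigger a step.
            if not is_accumulating:
                steps.append(batch_idx)
        return steps


    print(optimizer_step_batches(num_batches=24, accumulate_every=8))  # [7, 15, 23]

If the number of batches is not a multiple of the window size, the gradients of the last partial window are never applied unless you add an extra optimizer step at the end of the epoch.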

call
====

Use this to run all registered callback hooks with a given name and inputs.
It is useful when building a Trainer that allows the user to run arbitrary code at fixed points in the training loop.

.. code-block:: python

    class MyCallback:
        def on_train_start(self):
            ...

        def on_train_epoch_end(self, model, results):
            ...


    fabric = Fabric(callbacks=[MyCallback()])

    # Call any hook by name
    fabric.call("on_train_start")

    # Pass in additional arguments that the hook requires
    fabric.call("on_train_epoch_end", model=..., results={...})

    # Only the callbacks that have this method defined will be executed
    fabric.call("undefined")
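
The dispatch logic behind such a method can be sketched in a few lines of plain Python: look up the hook by name on each callback and call it only if it exists. ``CallbackDispatcher`` and the two example callbacks are hypothetical classes for illustration, not Fabric's implementation, which may differ in details such as error handling.

.. code-block:: python

    class CallbackDispatcher:
        """Hypothetical minimal dispatcher: run a named hook on every callback that has it."""

        def __init__(self, callbacks=None):
            self.callbacks = list(callbacks or [])

        def call(self, hook_name, *args, **kwargs):
            for callback in self.callbacks:
                method = getattr(callback, hook_name, None)
                # Skip callbacks that don't define this hook as a callable method.
                if callable(method):
                    method(*args, **kwargs)


    class PrintingCallback:
        def on_train_start(self):
            print("training started")


    class RecordingCallback:
        def __init__(self):
            self.epochs_seen = []

        def on_train_epoch_end(self, results):
            self.epochs_seen.append(results["epoch"])


    recorder = RecordingCallback()
    dispatcher = CallbackDispatcher(callbacks=[PrintingCallback(), recorder])
    dispatcher.call("on_train_start")  # only PrintingCallback runs
    dispatcher.call("on_train_epoch_end", results={"epoch": 0})  # only the recorder runs
    dispatcher.call("undefined")  # no callback defines it: nothing happens
    print(recorder.epochs_seen)  # [0]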