2023-01-10 19:11:03 +00:00
:orphan:
##############
Fabric Methods
##############
setup
=====
Set up a model and its corresponding optimizer(s). If you need to set up multiple models, call ``setup()`` on each of them.
The method moves the model and optimizer to the correct device automatically.

.. code-block:: python

    model = nn.Linear(32, 64)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Set up model and optimizer for accelerated training
    model, optimizer = fabric.setup(model, optimizer)

    # If you don't want Fabric to set the device
    model, optimizer = fabric.setup(model, optimizer, move_to_device=False)

The setup method also prepares the model for the selected precision so that operations during ``forward()`` are
cast automatically.
setup_dataloaders
=================
Set up one or multiple data loaders for accelerated operation. If you run a distributed strategy (e.g., DDP), Fabric
automatically replaces the sampler. In addition, the data loader will be configured to move the returned
data tensors to the correct device automatically.

.. code-block:: python

    train_data = torch.utils.data.DataLoader(train_dataset, ...)
    test_data = torch.utils.data.DataLoader(test_dataset, ...)

    train_data, test_data = fabric.setup_dataloaders(train_data, test_data)

    # If you don't want Fabric to move the data to the device
    train_data, test_data = fabric.setup_dataloaders(train_data, test_data, move_to_device=False)

    # If you don't want Fabric to replace the sampler in the context of distributed training
    train_data, test_data = fabric.setup_dataloaders(train_data, test_data, replace_sampler=False)

backward
========
This replaces every occurrence of ``loss.backward()`` and makes your code accelerator- and precision-agnostic.

.. code-block:: python

    output = model(input)
    loss = loss_fn(output, target)

    # loss.backward()
    fabric.backward(loss)

to_device
=========
Use :meth:`~lightning_fabric.fabric.Fabric.to_device` to move models, tensors, or collections of tensors to
the current device. By default, :meth:`~lightning_fabric.fabric.Fabric.setup` and
:meth:`~lightning_fabric.fabric.Fabric.setup_dataloaders` already move the model and data to the correct
device, so calling this method is only necessary when you move objects manually.

.. code-block:: python

    data = torch.load("dataset.pt")
    data = fabric.to_device(data)

seed_everything
===============
Make your code reproducible by calling this method at the beginning of your run.

.. code-block:: python

    # Instead of `torch.manual_seed(...)`, call:
    fabric.seed_everything(1234)

This covers PyTorch, NumPy, and Python random number generators. In addition, Fabric takes care of properly initializing
the seed of data loader worker processes (can be turned off by passing ``workers=False``).
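
As a minimal illustration of why seeding makes a run reproducible, the sketch below uses only Python's built-in
``random`` module. It is not how Fabric implements ``seed_everything``, which also covers PyTorch and NumPy, and the
helper name ``draw_samples`` is hypothetical.

.. code-block:: python

    import random


    def draw_samples(seed: int, n: int = 3) -> list:
        """Seed Python's RNG, then draw ``n`` pseudo-random floats."""
        random.seed(seed)
        return [random.random() for _ in range(n)]


    first_run = draw_samples(1234)
    second_run = draw_samples(1234)
    other_seed = draw_samples(4321)

    # The same seed reproduces the same sequence; a different seed does not
    print(first_run == second_run)
    print(first_run == other_seed)
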
autocast
========
Let the precision backend autocast the block of code under this context manager. This is optional and already done by
Fabric for the model's forward method (once the model has been set up with :meth:`~lightning_fabric.fabric.Fabric.setup`).
You need this only if you wish to autocast operations beyond the ones in the model's forward:

.. code-block:: python

    model, optimizer = fabric.setup(model, optimizer)

    # Fabric handles precision automatically for the model
    output = model(inputs)

    with fabric.autocast():  # optional
        loss = loss_function(output, target)

    fabric.backward(loss)
    ...

See also: :doc:`../fundamentals/precision`
print
=====
Print to the console via the built-in print function, but only on the main process.
This avoids excessive printing and logs when running on multiple devices/nodes.

.. code-block:: python

    # Print only on the main process
    fabric.print(f"{epoch}/{num_epochs}| Train Epoch Loss: {loss}")

save
====
Save the state of objects to a checkpoint file.
Replaces all occurrences of ``torch.save(...)`` in your code.
Fabric handles the saving correctly, whether you run on a single device, multiple devices, or multiple nodes.

.. code-block:: python

    # Define the state of your program/loop
    state = {
        "model1": model1,
        "model2": model2,
        "optimizer": optimizer,
        "iteration": iteration,
    }

    # Instead of `torch.save(...)`
    fabric.save("path/to/checkpoint.ckpt", state)

You should pass the model and optimizer objects directly into the dictionary so Fabric can unwrap them and automatically retrieve their *state-dict*.

See also: :doc:`../guide/checkpoint`
load
====
Load checkpoint contents from a file and restore the state of objects in your program.
Replaces all occurrences of ``torch.load(...)`` in your code.
Fabric handles the loading correctly, whether you run on a single device, multiple devices, or multiple nodes.

.. code-block:: python

    # Define the state of your program/loop
    state = {
        "model1": model1,
        "model2": model2,
        "optimizer": optimizer,
        "iteration": iteration,
    }

    # Restore the state of objects (in-place)
    fabric.load("path/to/checkpoint.ckpt", state)

    # Or load everything and restore your objects manually
    checkpoint = fabric.load("./checkpoints/version_2/checkpoint.ckpt")
    model.load_state_dict(checkpoint["model"])
    ...

See also: :doc:`../guide/checkpoint`
barrier
=======
Call this if you want all processes to wait and synchronize. Once all processes have entered this call,
execution continues. This is useful, for example, when you want to download data on one process and make all others wait until
the data is written to disk.

.. code-block:: python

    if fabric.global_rank == 0:
        print("Downloading dataset. This can take a while ...")
        download_dataset("http://...")

    # All other processes wait here until rank 0 is done with downloading:
    fabric.barrier()

    # After everyone reached the barrier, they can access the downloaded files:
    load_dataset()

See also: :doc:`../advanced/distributed_communication`
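
The waiting behavior of a barrier can be tried out without any distributed setup. The sketch below is only an analogy:
it simulates four "ranks" as threads and uses ``threading.Barrier`` from the Python standard library rather than
Fabric. The names ``NUM_WORKERS``, ``shared_files``, and ``observed`` are made up for this illustration.

.. code-block:: python

    import threading

    NUM_WORKERS = 4  # hypothetical number of "ranks", simulated as threads
    barrier = threading.Barrier(NUM_WORKERS)
    shared_files = {}  # stands in for the disk
    observed = [None] * NUM_WORKERS


    def worker(rank: int) -> None:
        if rank == 0:
            # Only "rank 0" prepares the data
            shared_files["dataset"] = "downloaded"
        # Every worker blocks here until all of them have arrived
        barrier.wait()
        # Past the barrier, the data written by rank 0 is visible to everyone
        observed[rank] = shared_files.get("dataset")


    threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(NUM_WORKERS)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    print(observed)

Because rank 0 writes before it reaches the barrier, no worker can read the data before it exists.
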
all_gather, all_reduce, broadcast
=================================
You can send tensors and other data between processes using collective operations.
The three most common ones, :meth:`~lightning_fabric.fabric.Fabric.broadcast`,
:meth:`~lightning_fabric.fabric.Fabric.all_gather`, and :meth:`~lightning_fabric.fabric.Fabric.all_reduce`,
are available directly on the Fabric object for convenience:

- :meth:`~lightning_fabric.fabric.Fabric.broadcast`: Send a tensor from one process to all others.
- :meth:`~lightning_fabric.fabric.Fabric.all_gather`: Gather tensors from every process and stack them.
- :meth:`~lightning_fabric.fabric.Fabric.all_reduce`: Apply a reduction function on tensors across processes (sum, mean, etc.).

.. code-block:: python

    # Send the value of a tensor from rank 0 to all others
    result = fabric.broadcast(tensor, src=0)

    # Every process gets the stack of tensors from everybody else
    all_tensors = fabric.all_gather(tensor)

    # Sum a tensor across processes (everyone gets the result)
    reduced_tensor = fabric.all_reduce(tensor, reduce_op="sum")

    # Also works with a collection of tensors (dict, list, tuple):
    collection = {"loss": torch.tensor(...), "data": ...}
    gathered_collection = fabric.all_gather(collection, ...)
    reduced_collection = fabric.all_reduce(collection, ...)


.. important::

    Every process needs to enter the collective calls. Otherwise, the program will hang!

Learn more about :doc:`distributed communication <../advanced/distributed_communication>`.
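
To make the result of each collective concrete, the following sketch simulates four processes with plain Python
lists. It involves no real communication and no Fabric calls; the list index plays the role of the rank, and the
values in ``per_rank_values`` are made up for this illustration.

.. code-block:: python

    # One value per simulated process; the list index is the rank (hypothetical data)
    per_rank_values = [1.0, 2.0, 3.0, 4.0]
    world_size = len(per_rank_values)

    # broadcast with src=0: every rank ends up with the value from rank 0
    broadcast_result = [per_rank_values[0] for _ in range(world_size)]

    # all_gather: every rank ends up with the values from all ranks
    all_gather_result = [list(per_rank_values) for _ in range(world_size)]

    # all_reduce with "sum": every rank ends up with the same reduced value
    all_reduce_result = [sum(per_rank_values) for _ in range(world_size)]

    print(broadcast_result)
    print(all_gather_result[0])
    print(all_reduce_result)

In each case, every simulated rank holds the same result afterward, which is what distinguishes these collectives from
point-to-point communication.
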
no_backward_sync
================
Use this context manager when performing gradient accumulation and using a distributed strategy (e.g., DDP).
It will speed up your training loop by cutting redundant communication between processes during the accumulation phase.

.. code-block:: python

    # Accumulate gradient 8 batches at a time
    is_accumulating = batch_idx % 8 != 0

    with fabric.no_backward_sync(model, enabled=is_accumulating):
        output = model(input)
        loss = ...
        fabric.backward(loss)
        ...

    # Step the optimizer every 8 batches
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()

Both the model's ``.forward()`` and the ``fabric.backward()`` call need to run under this context, as shown in the example above.
For single-device strategies, it is a no-op. Some strategies don't support this:

- deepspeed
- dp
- xla

For these, the context manager falls back to a no-op and emits a warning.
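
To see which batches trigger an optimizer step under the condition used in the example above, the schedule can be
worked out in plain Python. This sketch assumes a 0-based ``batch_idx``; the helper ``is_accumulating`` and its
``accumulate_every`` parameter are hypothetical names for this illustration.

.. code-block:: python

    def is_accumulating(batch_idx: int, accumulate_every: int = 8) -> bool:
        """Mirror the condition ``batch_idx % 8 != 0`` from the example above."""
        return batch_idx % accumulate_every != 0


    # Batch indices (0-based) at which the optimizer would step
    step_indices = [idx for idx in range(24) if not is_accumulating(idx)]
    print(step_indices)  # [0, 8, 16]

    # A variant that counts batches from 1, so the first step happens
    # only after 8 batches have been accumulated
    variant_steps = [idx for idx in range(24) if (idx + 1) % 8 == 0]
    print(variant_steps)  # [7, 15, 23]

With a 0-based index, the first condition steps on the very first batch. Whether that matters depends on your loop;
the variant is one possible way to delay the first step until a full window has accumulated.
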
call
====
Use this to run all registered callback hooks with a given name and inputs.
It is useful when building a Trainer that allows the user to run arbitrary code at fixed points in the training loop.

.. code-block:: python

    class MyCallback:
        def on_train_start(self):
            ...

        def on_train_epoch_end(self, model, results):
            ...


    fabric = Fabric(callbacks=[MyCallback()])

    # Call any hook by name
    fabric.call("on_train_start")

    # Pass in additional arguments that the hook requires
    fabric.call("on_train_epoch_end", model=..., results={...})

    # Only the callbacks that have this method defined will be executed
    fabric.call("undefined")

See also: :doc:`../guide/callbacks`
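
The idea of calling hooks by name can be sketched in a few lines of plain Python. This is not Fabric's actual
implementation. The function ``call_hook`` and the class ``OtherCallback`` are hypothetical and only illustrate the
behavior described above: callbacks that do not define the hook are skipped.

.. code-block:: python

    class MyCallback:
        def __init__(self):
            self.events = []

        def on_train_start(self):
            self.events.append("start")

        def on_train_epoch_end(self, model, results):
            self.events.append(("epoch_end", results))


    class OtherCallback:
        """Defines no hooks, so it is skipped for every hook name."""


    def call_hook(callbacks, hook_name, *args, **kwargs):
        """Hypothetical dispatcher: run ``hook_name`` on each callback that defines it."""
        for callback in callbacks:
            method = getattr(callback, hook_name, None)
            if callable(method):
                method(*args, **kwargs)


    my_callback = MyCallback()
    callbacks = [my_callback, OtherCallback()]

    call_hook(callbacks, "on_train_start")
    call_hook(callbacks, "on_train_epoch_end", model=None, results={"loss": 0.5})
    call_hook(callbacks, "undefined")  # no callback defines this, so nothing runs

    print(my_callback.events)
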
log and log_dict
================
These methods allow you to send scalar metrics to a logger registered in Fabric.

.. code-block:: python

    # Set the logger in Fabric
    fabric = Fabric(loggers=TensorBoardLogger(...))

    # Anywhere in your training loop or model:
    fabric.log("loss", loss)

    # Or send multiple metrics at once:
    fabric.log_dict({"loss": loss, "accuracy": acc})

If no loggers are given to Fabric (default), ``log`` and ``log_dict`` won't do anything.
Here is what's happening under the hood (pseudocode) when you call ``.log()`` or ``.log_dict()``:

.. code-block:: python

    # When you call .log() or .log_dict(), we do this:
    for logger in fabric.loggers:
        logger.log_metrics(metrics=metrics, step=step)

See also: :doc:`../guide/logging`
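
The pseudocode above can be turned into a small runnable sketch. Everything here is an assumption made for
illustration: ``InMemoryLogger`` is a hypothetical logger that only needs a ``log_metrics(metrics, step)`` method, and
the free functions ``log`` and ``log_dict`` stand in for the Fabric methods rather than reproduce them.

.. code-block:: python

    class InMemoryLogger:
        """Hypothetical logger that stores metrics instead of writing them anywhere."""

        def __init__(self):
            self.records = []

        def log_metrics(self, metrics, step=None):
            self.records.append((step, dict(metrics)))


    def log_dict(loggers, metrics, step=None):
        # Forward the same metrics to every registered logger
        for logger in loggers:
            logger.log_metrics(metrics=metrics, step=step)


    def log(loggers, name, value, step=None):
        # Logging a single value is a special case of logging a dictionary
        log_dict(loggers, {name: value}, step=step)


    loggers = [InMemoryLogger(), InMemoryLogger()]
    log(loggers, "loss", 0.25, step=0)
    log_dict(loggers, {"loss": 0.2, "accuracy": 0.9}, step=1)
    log_dict([], {"loss": 0.1})  # no loggers registered: nothing happens

    print(loggers[0].records)
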