:orphan:

.. _checkpointing_expert:

################################
Distributed checkpoints (expert)
################################

Generally, the bigger your model is, the longer it takes to save a checkpoint to disk.
With distributed checkpoints (sometimes called sharded checkpoints), you can save and load the state of your training script with multiple GPUs or nodes more efficiently, avoiding memory issues.


----

*****************************
Save a distributed checkpoint
*****************************

The distributed checkpoint format can be enabled when you train with the :doc:`FSDP strategy <../advanced/model_parallel/fsdp>`.
.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    # 1. Select the FSDP strategy and set the sharded/distributed checkpoint format
    strategy = FSDPStrategy(state_dict_type="sharded")

    # 2. Pass the strategy to the Trainer
    trainer = L.Trainer(devices=2, strategy=strategy, ...)

    # 3. Run the trainer
    trainer.fit(model)
With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
This reduces memory peaks and speeds up the saving to disk.
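
The destination folder follows your regular checkpointing settings. For example, you can set it explicitly through the standard ``ModelCheckpoint`` callback; a minimal sketch (the ``dirpath`` value here is just an illustration):

.. code-block:: python

    import lightning as L
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.strategies import FSDPStrategy

    strategy = FSDPStrategy(state_dict_type="sharded")

    # Each saved checkpoint becomes a folder under dirpath, with one shard file per process
    checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/")
    trainer = L.Trainer(devices=2, strategy=strategy, callbacks=[checkpoint_callback])
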
.. collapse:: Full example

    .. code-block:: python

        import lightning as L
        from lightning.pytorch.strategies import FSDPStrategy
        from lightning.pytorch.demos import LightningTransformer

        model = LightningTransformer()

        strategy = FSDPStrategy(state_dict_type="sharded")
        trainer = L.Trainer(
            accelerator="cuda",
            devices=4,
            strategy=strategy,
            max_steps=3,
        )
        trainer.fit(model)

    Check the contents of the checkpoint folder:

    .. code-block:: bash

        ls -a lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt/

    .. code-block::

        epoch=0-step=3.ckpt/
        ├── __0_0.distcp
        ├── __1_0.distcp
        ├── __2_0.distcp
        ├── __3_0.distcp
        ├── .metadata
        └── meta.pt

    The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files
    is roughly 1/4 of the total size of the checkpoint, since the script distributes the model across 4 GPUs.
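
    As a quick sanity check, you can measure the shard sizes programmatically. A minimal sketch, assuming the checkpoint folder produced by the example above:

    .. code-block:: python

        from pathlib import Path

        ckpt_dir = Path("lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt")

        # Each .distcp file holds the tensor shard written by one process/GPU
        for shard in sorted(ckpt_dir.glob("*.distcp")):
            print(f"{shard.name}: {shard.stat().st_size / 1e6:.1f} MB")
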
----

*****************************
Load a distributed checkpoint
*****************************

You can easily load a distributed checkpoint through the Trainer if your script uses :doc:`FSDP <../advanced/model_parallel/fsdp>`.
.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    # 1. Select the FSDP strategy and set the sharded/distributed checkpoint format
    strategy = FSDPStrategy(state_dict_type="sharded")

    # 2. Pass the strategy to the Trainer
    trainer = L.Trainer(devices=2, strategy=strategy, ...)

    # 3. Set the checkpoint path to load
    trainer.fit(model, ckpt_path="path/to/checkpoint")
Note that you can load the distributed checkpoint even if the world size has changed, i.e., you are running on a different number of GPUs than when you saved the checkpoint.
.. collapse:: Full example

    .. code-block:: python

        import lightning as L
        from lightning.pytorch.strategies import FSDPStrategy
        from lightning.pytorch.demos import LightningTransformer

        model = LightningTransformer()

        strategy = FSDPStrategy(state_dict_type="sharded")
        trainer = L.Trainer(
            accelerator="cuda",
            devices=2,
            strategy=strategy,
            max_steps=5,
        )

        # Loads a checkpoint that was saved with 4 GPUs, even though we now run on 2
        trainer.fit(model, ckpt_path="lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt")
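
The same ``ckpt_path`` argument is accepted by the other ``Trainer`` entry points as well. For example, to evaluate from a distributed checkpoint (a brief sketch, reusing the trainer and model from the example above):

.. code-block:: python

    trainer.validate(model, ckpt_path="lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt")
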
.. important::

    If you want to load a distributed checkpoint into a script that doesn't use FSDP (or Trainer at all), then you will have to :ref:`convert it to a single-file checkpoint first <Convert dist-checkpoint>`.

----

.. _Convert dist-checkpoint:

********************************
Convert a distributed checkpoint
********************************

It is possible to convert a distributed checkpoint to a regular, single-file checkpoint with this utility:
.. code-block:: bash

    python -m lightning.pytorch.utilities.consolidate_checkpoint path/to/my/checkpoint
You will need to do this, for example, if you want to load the checkpoint into a script that doesn't use FSDP, or if you need to export the checkpoint to a different format for deployment, evaluation, etc.
.. note::

    All tensors in the checkpoint will be converted to CPU tensors, and no GPUs are required to run the conversion command.
    The conversion assumes you have enough free CPU memory to hold the entire checkpoint at once.
.. collapse:: Full example

    Assuming you have saved a checkpoint ``epoch=0-step=3.ckpt`` using the examples above, run the following command to convert it:

    .. code-block:: bash

        cd lightning_logs/version_0/checkpoints
        python -m lightning.pytorch.utilities.consolidate_checkpoint epoch=0-step=3.ckpt

    This saves a new file ``epoch=0-step=3.ckpt.consolidated`` next to the sharded checkpoint, which you can load normally in PyTorch:

    .. code-block:: python

        import torch

        checkpoint = torch.load("epoch=0-step=3.ckpt.consolidated")
        print(list(checkpoint.keys()))
        print(checkpoint["state_dict"]["model.transformer.decoder.layers.31.norm1.weight"])
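
    To reuse the weights without Lightning, load them into the plain PyTorch module. A minimal sketch, assuming the ``LightningTransformer`` from the examples above, which keeps the network in an attribute named ``model`` (hence the ``model.`` prefix on every key):

    .. code-block:: python

        import torch
        from lightning.pytorch.demos import Transformer

        checkpoint = torch.load("epoch=0-step=3.ckpt.consolidated")

        # Strip the "model." prefix that the LightningModule attribute added to each key
        state_dict = {k.removeprefix("model."): v for k, v in checkpoint["state_dict"].items()}

        # The architecture must match training; vocab_size=33278 is the demo default
        model = Transformer(vocab_size=33278)
        model.load_state_dict(state_dict)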