##############################
Saving and Loading Checkpoints
##############################

Fabric makes it easy and efficient to save the state of your training loop into a checkpoint file, no matter how large your model is.

----

********************************
Define the state of your program
********************************

To save and resume your training, you need to define which variables in your program should be saved.
Put everything into a dictionary, including models, optimizers, and any metadata you have:

.. code-block:: python

    # Define the state of your program/loop
    state = {"model1": model1, "model2": model2, "optimizer": optimizer, "iteration": iteration, "hparams": ...}

----

*****************
Save a checkpoint
*****************

To save the state to the filesystem, pass it to the :meth:`~lightning.fabric.fabric.Fabric.save` method:

.. code-block:: python

    fabric.save("path/to/checkpoint.ckpt", state)

This will unwrap your model and optimizer and automatically convert their ``state_dict`` for you.
Fabric and the underlying strategy will decide in which format your checkpoint gets saved.
For example, ``strategy="ddp"`` saves a single file on rank 0, while ``strategy="fsdp"`` saves multiple files from all ranks.
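
The saving code itself does not change between strategies; only the on-disk layout does. Here is a minimal sketch of that idea (the tiny model, optimizer, and two-process CPU setup are placeholder assumptions, not part of the example above):

.. code-block:: python

    import torch
    from lightning.fabric import Fabric

    # Placeholder model and optimizer; any nn.Module and optimizer work the same way.
    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # With strategy="fsdp" (on GPUs), the same save call below would write a
    # sharded checkpoint instead of a single rank-0 file.
    fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
    fabric.launch()
    model, optimizer = fabric.setup(model, optimizer)

    state = {"model": model, "optimizer": optimizer, "iteration": 0}

    # Identical call for either strategy; Fabric and the strategy pick the format.
    fabric.save("path/to/checkpoint.ckpt", state)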

----

*************************
Restore from a checkpoint
*************************

You can restore the state by loading a saved checkpoint back with :meth:`~lightning.fabric.fabric.Fabric.load`:

.. code-block:: python

    fabric.load("path/to/checkpoint.ckpt", state)

Fabric will replace the state of your objects in-place.
You can also restore only a portion of the checkpoint.
For example, you may want to restore only the model weights in your inference script:

.. code-block:: python

    state = {"model1": model1}
    remainder = fabric.load("path/to/checkpoint.ckpt", state)

The remainder of the checkpoint (everything that wasn't restored) is returned, in case you want to do something else with it.
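
For instance, continuing the snippet above and assuming (hypothetically) that the checkpoint was saved with an ``"iteration"`` entry, as in the earlier ``state`` example:

.. code-block:: python

    # Keys that were not requested in `state`, such as "iteration" or
    # "hparams", come back in the returned dictionary.
    iteration = remainder.get("iteration", 0)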

If you want complete control over how states get restored, you can omit the state argument and get the entire raw checkpoint dictionary returned:

.. code-block:: python

    # Request the raw checkpoint
    full_checkpoint = fabric.load("path/to/checkpoint.ckpt")

    model.load_state_dict(full_checkpoint["model"])
    optimizer.load_state_dict(full_checkpoint["optimizer"])
    ...

----

**********
Next steps
**********

Learn from our template how Fabric's checkpoint mechanism can be integrated into a full Trainer:

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
    :header: Trainer Template
    :description: Take our Fabric Trainer template and customize it for your needs
    :button_link: https://github.com/Lightning-AI/lightning/tree/master/examples/fabric/build_your_own_trainer
    :col_css: col-md-4
    :height: 150
    :tag: intermediate

.. raw:: html

        </div>
    </div>