lightning/docs/source-pytorch/common/checkpointing_expert.rst

127 lines
4.5 KiB
ReStructuredText

:orphan:
.. _checkpointing_expert:
################################
Distributed checkpoints (expert)
################################
*********************************
Writing your own Checkpoint class
*********************************
We provide ``Checkpoint`` class, for easier subclassing. Users may want to subclass this class in case of writing custom ``ModelCheckpoint`` callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback.
***********************
Customize Checkpointing
***********************
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic
that is managed by the ``Strategy``. ``CheckpointIO`` is different from :meth:`~lightning.pytorch.core.hooks.CheckpointHooks.on_save_checkpoint`
and :meth:`~lightning.pytorch.core.hooks.CheckpointHooks.on_load_checkpoint` methods as it determines how the checkpoint is saved/loaded to storage rather than
what's saved in the checkpoint.
TODO: I don't understand this...
******************************
Built-in Checkpoint IO Plugins
******************************
.. list-table:: Built-in Checkpoint IO Plugins
:widths: 25 75
:header-rows: 1
* - Plugin
- Description
* - :class:`~lightning.pytorch.plugins.io.TorchCheckpointIO`
- CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
respectively, common for most use cases.
* - :class:`~lightning.pytorch.plugins.io.XLACheckpointIO`
- CheckpointIO that utilizes ``xm.save`` to save checkpoints for TPU training strategies.
* - :class:`~lightning.pytorch.plugins.io.AsyncCheckpointIO`
- ``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
***************************
Custom Checkpoint IO Plugin
***************************
``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a ``Trainer`` directly or a ``Strategy`` as shown below:
.. code-block:: python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.plugins import CheckpointIO
from lightning.pytorch.strategies import SingleDeviceStrategy
class CustomCheckpointIO(CheckpointIO):
def save_checkpoint(self, checkpoint, path, storage_options=None):
...
def load_checkpoint(self, path, storage_options=None):
...
def remove_checkpoint(self, path):
...
custom_checkpoint_io = CustomCheckpointIO()
# Either pass into the Trainer object
model = MyModel()
trainer = Trainer(
plugins=[custom_checkpoint_io],
callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)
# or pass into Strategy
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
strategy=SingleDeviceStrategy(device, checkpoint_io=custom_checkpoint_io),
callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)
.. note::
Some ``Strategy``s like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable.
**************************
Asynchronous Checkpointing
**************************
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
To enable saving the checkpoints asynchronously without blocking your training, you can configure
:class:`~lightning.pytorch.plugins.io.async_plugin.AsyncCheckpointIO` plugin to ``Trainer``.
.. code-block:: python
from lightning.pytorch.plugins.io import AsyncCheckpointIO
async_ckpt_io = AsyncCheckpointIO()
trainer = Trainer(plugins=[async_ckpt_io])
It uses its base ``CheckpointIO`` plugin's saving logic to save the checkpoint but performs this operation asynchronously.
By default, this base ``CheckpointIO`` will be set-up for you and all you need to provide is the ``AsyncCheckpointIO`` instance to the ``Trainer``.
But if you want the plugin to use your own custom base ``CheckpointIO`` and want the base to behave asynchronously, pass it as an argument while initializing ``AsyncCheckpointIO``.
.. code-block:: python
from lightning.pytorch.plugins.io import AsyncCheckpointIO
base_ckpt_io = MyCustomCheckpointIO()
async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
trainer = Trainer(plugins=[async_ckpt_io])