[docs] Update `checkpointing.rst` and `callbacks.rst` for Stateful support (#12351)

jjenniferdai 2022-03-24 17:20:21 -07:00 committed by GitHub
parent 4ca3572051
commit ec7fa1e2d8
2 changed files with 34 additions and 14 deletions

checkpointing.rst

@@ -28,11 +28,27 @@ A Lightning checkpoint has everything needed to restore a training session including:
- LightningModule's state_dict
- State of all optimizers
- State of all learning rate schedulers
- State of all callbacks
- State of all callbacks (for stateful callbacks)
- State of datamodule (for stateful datamodules)
- The hyperparameters used for that model if passed in as hparams (argparse.Namespace)
- State of Loops (if using Fault-Tolerant training)
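Since a Lightning checkpoint is written with ``torch.save`` as a plain dictionary, its top-level entries can be inspected directly. A minimal sketch (the path ``example.ckpt`` is illustrative):

.. code-block:: python

    import torch

    # Load the raw checkpoint dictionary onto the CPU and list its entries.
    checkpoint = torch.load("example.ckpt", map_location="cpu")
    print(list(checkpoint.keys()))
    # Typical entries include "state_dict", "optimizer_states", "lr_schedulers",
    # "callbacks", and "hyper_parameters".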
Individual Component States
===========================
Each component can save and load its state by implementing the PyTorch ``state_dict`` and ``load_state_dict`` stateful protocol.
For details on implementing your own stateful callbacks and datamodules, refer to the individual docs pages at :doc:`callbacks <../extensions/callbacks>` and :doc:`datamodules <../extensions/datamodules>`.
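As a sketch of what this protocol looks like for a datamodule (the ``MyStatefulDataModule`` class and its ``current_fold`` attribute are purely illustrative):

.. code-block:: python

    from pytorch_lightning import LightningDataModule


    class MyStatefulDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.current_fold = 0  # internal state worth restoring across runs

        def state_dict(self):
            # Whatever is returned here is stored under the datamodule's entry
            # in the checkpoint dictionary; it must be picklable.
            return {"current_fold": self.current_fold}

        def load_state_dict(self, state_dict):
            # Receives the dictionary produced by ``state_dict`` when the
            # checkpoint is loaded.
            self.current_fold = state_dict["current_fold"]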
Operating on Global Checkpoint Component States
===============================================
If you need to operate on the global component state (i.e. the entire checkpoint dictionary), you can read/add/delete/modify custom states in your checkpoints before they are saved or loaded.
For this, you can override the :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint` and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` hooks in your ``LightningModule``,
or the :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
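A minimal sketch of what that can look like in a ``LightningModule`` (the ``something_extra`` key is purely illustrative):

.. code-block:: python

    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):
        def on_save_checkpoint(self, checkpoint):
            # ``checkpoint`` is the full dictionary about to be written to disk,
            # so extra picklable entries can be added, changed, or removed here.
            checkpoint["something_extra"] = 123

        def on_load_checkpoint(self, checkpoint):
            # The same dictionary is passed back when the checkpoint is loaded.
            self.something_extra = checkpoint.get("something_extra")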
*****************
Checkpoint Saving
*****************
@@ -102,14 +118,6 @@ If using custom saving functions cannot be avoided, we recommend using the :func
model parallel distributed strategies such as deepspeed or sharded training.
Modifying Checkpoint When Saving and Loading
============================================
You can add/delete/modify custom states in your checkpoints before they are saved or loaded. For this, you can override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule`` or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and
:meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
Checkpointing Hyperparameters
=============================

callbacks.rst

@@ -116,8 +116,8 @@ Persisting State
----------------
Some callbacks require internal state in order to function properly. You can optionally
choose to persist your callback's state as part of model checkpoint files using the callback hooks
:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
choose to persist your callback's state as part of model checkpoint files using
:meth:`~pytorch_lightning.callbacks.Callback.state_dict` and :meth:`~pytorch_lightning.callbacks.Callback.load_state_dict`.
Note that the returned state must be able to be pickled.
When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough
@@ -147,10 +147,10 @@ the following example.
        if self.what == "batches":
            self.state["batches"] += 1
    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        self.state.update(callback_state)
    def load_state_dict(self, state_dict):
        self.state.update(state_dict)
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
    def state_dict(self):
        return self.state.copy()
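Pieced together, the updated example reads roughly as follows. This is a sketch assuming the ``Counter`` callback from the surrounding docs page, with the hook signatures shortened to ``*args, **kwargs`` for brevity:

.. code-block:: python

    from pytorch_lightning.callbacks import Callback


    class Counter(Callback):
        def __init__(self, what="epochs"):
            self.what = what
            self.state = {"epochs": 0, "batches": 0}

        def on_train_epoch_end(self, *args, **kwargs):
            if self.what == "epochs":
                self.state["epochs"] += 1

        def on_train_batch_end(self, *args, **kwargs):
            if self.what == "batches":
                self.state["batches"] += 1

        def load_state_dict(self, state_dict):
            # Restore the counters saved in the checkpoint.
            self.state.update(state_dict)

        def state_dict(self):
            # Return a picklable copy of the counters to be saved.
            return self.state.copy()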
@@ -422,12 +422,24 @@ on_exception
.. automethod:: pytorch_lightning.callbacks.Callback.on_exception
:noindex:
state_dict
~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.state_dict
:noindex:
on_save_checkpoint
~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint
:noindex:
load_state_dict
~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.load_state_dict
:noindex:
on_load_checkpoint
~~~~~~~~~~~~~~~~~~