[docs] Update `checkpointing.rst` and `callbacks.rst` for Stateful support (#12351)
parent 4ca3572051
commit ec7fa1e2d8
@@ -28,11 +28,27 @@ A Lightning checkpoint has everything needed to restore a training session inclu
 - LightningModule's state_dict
 - State of all optimizers
 - State of all learning rate schedulers
-- State of all callbacks
+- State of all callbacks (for stateful callbacks)
+- State of datamodule (for stateful datamodules)
 - The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)
 - State of Loops (if using Fault-Tolerant training)


+Individual Component States
+===========================
+
+Each component can save and load its state by implementing the PyTorch ``state_dict``, ``load_state_dict`` stateful protocol.
+For details on implementing your own stateful callbacks and datamodules, refer to the individual docs pages at :doc:`callbacks <../extensions/callbacks>` and :doc:`datamodules <../extensions/datamodules>`.
+
+
+Operating on Global Checkpoint Component States
+===============================================
+
+If you need to operate on the global component state (i.e. the entire checkpoint dictionary), you can read/add/delete/modify custom states in your checkpoints before they are being saved or loaded.
+For this you can override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint` and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule``
+or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
+
+
 *****************
 Checkpoint Saving
 *****************
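To make the list of component states above concrete: a Lightning checkpoint is a plain dictionary saved with ``torch.save``, so its contents can be inspected directly. A minimal sketch, assuming a checkpoint file at the hypothetical path ``example.ckpt``; the exact set of top-level keys depends on the Lightning version and on which components are stateful:

.. code-block:: python

    import torch

    # Load the raw checkpoint dictionary (no LightningModule needed for inspection).
    checkpoint = torch.load("example.ckpt", map_location="cpu")

    # Top-level entries cover the component states listed above, e.g. the
    # LightningModule weights, optimizer and LR scheduler states, and any
    # stateful callback/datamodule states.
    print(list(checkpoint.keys()))

    # The LightningModule's own parameters and buffers live under "state_dict".
    print(list(checkpoint["state_dict"].keys()))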
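The new "Individual Component States" section refers to the ``state_dict`` / ``load_state_dict`` stateful protocol. A minimal sketch of a stateful callback, using a hypothetical ``RunningCount`` class; the same pattern applies to a ``LightningDataModule``:

.. code-block:: python

    from pytorch_lightning.callbacks import Callback


    class RunningCount(Callback):
        """Hypothetical callback that persists a picklable counter across checkpoints."""

        def __init__(self):
            self.count = 0

        def on_train_batch_end(self, *args, **kwargs):
            self.count += 1

        def state_dict(self):
            # The returned state is stored with the callback states in the checkpoint.
            return {"count": self.count}

        def load_state_dict(self, state_dict):
            self.count = state_dict["count"]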
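For the "Operating on Global Checkpoint Component States" section, a sketch of overriding the two hooks on a ``LightningModule``; the ``my_custom_state`` key is a made-up example, not an official checkpoint entry:

.. code-block:: python

    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def on_save_checkpoint(self, checkpoint):
            # `checkpoint` is the full checkpoint dictionary about to be written.
            checkpoint["my_custom_state"] = {"note": "anything picklable"}

        def on_load_checkpoint(self, checkpoint):
            # Read the custom entry back; it may be absent in older checkpoints.
            my_state = checkpoint.get("my_custom_state", {})

The ``Callback`` versions of the same hooks additionally receive the ``trainer`` and ``pl_module`` arguments, as the ``callbacks.rst`` hunks further down show.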
@@ -102,14 +118,6 @@ If using custom saving functions cannot be avoided, we recommend using the :func
 model parallel distributed strategies such as deepspeed or sharded training.


-Modifying Checkpoint When Saving and Loading
-============================================
-
-You can add/delete/modify custom states in your checkpoints before they are being saved or loaded. For this you can override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
-and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule`` or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and
-:meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
-
-
 Checkpointing Hyperparameters
 =============================

@@ -116,8 +116,8 @@ Persisting State
 ----------------

 Some callbacks require internal state in order to function properly. You can optionally
-choose to persist your callback's state as part of model checkpoint files using the callback hooks
-:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
+choose to persist your callback's state as part of model checkpoint files using
+:meth:`~pytorch_lightning.callbacks.Callback.state_dict` and :meth:`~pytorch_lightning.callbacks.Callback.load_state_dict`.
 Note that the returned state must be able to be pickled.

 When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough
@@ -147,10 +147,10 @@ the following example.
         if self.what == "batches":
             self.state["batches"] += 1

-    def on_load_checkpoint(self, trainer, pl_module, callback_state):
-        self.state.update(callback_state)
+    def load_state_dict(self, state_dict):
+        self.state.update(state_dict)

-    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+    def state_dict(self):
         return self.state.copy()

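The hunk above shows only the methods that changed. For context, a minimal sketch of how the full example might look after this change, assuming a counter-style callback built around the ``self.what`` and ``self.state`` attributes visible in the fragment:

.. code-block:: python

    from pytorch_lightning.callbacks import Callback


    class Counter(Callback):
        def __init__(self, what="epochs"):
            self.what = what
            self.state = {"epochs": 0, "batches": 0}

        def on_train_epoch_end(self, *args, **kwargs):
            if self.what == "epochs":
                self.state["epochs"] += 1

        def on_train_batch_end(self, *args, **kwargs):
            if self.what == "batches":
                self.state["batches"] += 1

        def load_state_dict(self, state_dict):
            # Restore the persisted counters when the checkpoint is loaded.
            self.state.update(state_dict)

        def state_dict(self):
            # Return a picklable copy; it is saved with the callback states.
            return self.state.copy()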
@@ -422,12 +422,24 @@ on_exception
 .. automethod:: pytorch_lightning.callbacks.Callback.on_exception
     :noindex:

+state_dict
+~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.callbacks.Callback.state_dict
+    :noindex:
+
 on_save_checkpoint
 ~~~~~~~~~~~~~~~~~~

 .. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint
     :noindex:

+load_state_dict
+~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.callbacks.Callback.load_state_dict
+    :noindex:
+
 on_load_checkpoint
 ~~~~~~~~~~~~~~~~~~
