diff --git a/docs/source/common/checkpointing.rst b/docs/source/common/checkpointing.rst
index 8d96bfb3c1..7349605279 100644
--- a/docs/source/common/checkpointing.rst
+++ b/docs/source/common/checkpointing.rst
@@ -28,11 +28,27 @@ A Lightning checkpoint has everything needed to restore a training session inclu
 - LightningModule's state_dict
 - State of all optimizers
 - State of all learning rate schedulers
-- State of all callbacks
+- State of all callbacks (for stateful callbacks)
+- State of datamodule (for stateful datamodules)
 - The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)
 - State of Loops (if using Fault-Tolerant training)
 
 
+Individual Component States
+===========================
+
+Each component can save and load its state by implementing the PyTorch ``state_dict``, ``load_state_dict`` stateful protocol.
+For details on implementing your own stateful callbacks and datamodules, refer to the individual docs pages at :doc:`callbacks <../extensions/callbacks>` and :doc:`datamodules <../extensions/datamodules>`.
+
+
+Operating on Global Checkpoint Component States
+===============================================
+
+If you need to operate on the global component state (i.e. the entire checkpoint dictionary), you can read/add/delete/modify custom states in your checkpoints before they are being saved or loaded.
+For this you can override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint` and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule``
+or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
+
+
 *****************
 Checkpoint Saving
 *****************
@@ -102,14 +118,6 @@ If using custom saving functions cannot be avoided, we recommend using the :func
 model parallel distributed strategies such as deepspeed or sharded training.
 
 
-Modifying Checkpoint When Saving and Loading
-============================================
-
-You can add/delete/modify custom states in your checkpoints before they are being saved or loaded. For this you can override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
-and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule`` or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and
-:meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
-
-
 Checkpointing Hyperparameters
 =============================
 
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index d5f02ee5c1..19b43f4f4d 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -116,8 +116,8 @@ Persisting State
 ----------------
 
 Some callbacks require internal state in order to function properly. You can optionally
-choose to persist your callback's state as part of model checkpoint files using the callback hooks
-:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
+choose to persist your callback's state as part of model checkpoint files using
+:meth:`~pytorch_lightning.callbacks.Callback.state_dict` and :meth:`~pytorch_lightning.callbacks.Callback.load_state_dict`.
 Note that the returned state must be able to be pickled.
 
 When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough
@@ -147,10 +147,10 @@ the following example.
             if self.what == "batches":
                 self.state["batches"] += 1
 
-        def on_load_checkpoint(self, trainer, pl_module, callback_state):
-            self.state.update(callback_state)
+        def load_state_dict(self, state_dict):
+            self.state.update(state_dict)
 
-        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+        def state_dict(self):
             return self.state.copy()
 
 
@@ -422,12 +422,24 @@ on_exception
 
 .. automethod:: pytorch_lightning.callbacks.Callback.on_exception
    :noindex:
 
+state_dict
+~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.callbacks.Callback.state_dict
+   :noindex:
+
 on_save_checkpoint
 ~~~~~~~~~~~~~~~~~~
 
 .. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint
    :noindex:
 
+load_state_dict
+~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.callbacks.Callback.load_state_dict
+   :noindex:
+
 on_load_checkpoint
 ~~~~~~~~~~~~~~~~~~
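
Taken together, the two files above describe one migration: per-callback state moves from the ``on_save_checkpoint``/``on_load_checkpoint`` hooks to the ``state_dict``/``load_state_dict`` protocol. A minimal runnable sketch of a stateful callback under the new protocol — the method names and signatures are confirmed by this diff, while the constructor and hook bodies are reconstructed from the context lines of the ``Counter`` example and are illustrative rather than verbatim::

    import pytorch_lightning as pl


    class Counter(pl.Callback):
        def __init__(self, what="epochs"):
            self.what = what
            self.state = {"epochs": 0, "batches": 0}

        def on_train_epoch_end(self, *args, **kwargs):
            if self.what == "epochs":
                self.state["epochs"] += 1

        def on_train_batch_end(self, *args, **kwargs):
            if self.what == "batches":
                self.state["batches"] += 1

        def state_dict(self):
            # Written into the checkpoint's callback state; must be picklable.
            return self.state.copy()

        def load_state_dict(self, state_dict):
            # Receives exactly the dict returned by state_dict() above.
            self.state.update(state_dict)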
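
The hooks themselves are not removed: as the new "Operating on Global Checkpoint Component States" section says, they remain the way to read or modify the entire checkpoint dictionary. A minimal sketch in a ``LightningModule``, where the ``my_custom_state`` key and its contents are purely illustrative::

    import pytorch_lightning as pl
    import torch


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def on_save_checkpoint(self, checkpoint):
            # checkpoint is the complete dictionary about to be written to
            # disk; extra keys can live next to "state_dict", "epoch", etc.
            checkpoint["my_custom_state"] = {"favorite_number": 42}

        def on_load_checkpoint(self, checkpoint):
            # The same dictionary, read back before each component's state
            # is restored.
            self.favorite_number = checkpoint["my_custom_state"]["favorite_number"]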