*******************
Save Callback state
*******************

Some callbacks require internal state in order to function properly. You can optionally
persist your callback's state as part of model checkpoint files by implementing
:meth:`~lightning.pytorch.callbacks.Callback.state_dict` and :meth:`~lightning.pytorch.callbacks.Callback.load_state_dict`.
Note that the returned state must be picklable.
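
One way to catch an unpicklable state early is to round-trip it through ``pickle`` before relying on it in a
checkpoint. The sketch below is a standalone illustration that uses only the standard library; the helper name
``is_picklable`` is hypothetical and not part of Lightning.

.. code-block:: python

    import pickle


    def is_picklable(state):
        """Return True if ``state`` survives a pickle round trip unchanged."""
        try:
            return pickle.loads(pickle.dumps(state)) == state
        except (pickle.PicklingError, AttributeError, TypeError):
            return False


    # Plain containers of numbers and strings pickle without trouble.
    good_state = {"epochs": 2, "batches": 32}

    # A lambda cannot be pickled, so a state like this would fail when saved.
    bad_state = {"epochs": 2, "on_done": lambda: None}

Under these assumptions, ``is_picklable(good_state)`` is true and ``is_picklable(bad_state)`` is false.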

When your callback is meant to be used only as a singleton callback, implementing the above two hooks is enough
to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported,
the callback must also define a :attr:`~lightning.pytorch.callbacks.Callback.state_key` property so that Lightning
can distinguish the different states when loading the callback state. This concept is best illustrated by
the following example.

.. testcode::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import Callback


    class Counter(Callback):
        def __init__(self, what="epochs", verbose=True):
            self.what = what
            self.verbose = verbose
            self.state = {"epochs": 0, "batches": 0}

        @property
        def state_key(self) -> str:
            # note: we do not include `verbose` here on purpose
            return f"Counter[what={self.what}]"

        def on_train_epoch_end(self, *args, **kwargs):
            if self.what == "epochs":
                self.state["epochs"] += 1

        def on_train_batch_end(self, *args, **kwargs):
            if self.what == "batches":
                self.state["batches"] += 1

        def load_state_dict(self, state_dict):
            self.state.update(state_dict)

        def state_dict(self):
            return self.state.copy()


    # two callbacks of the same type are being used
    trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])


A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:

.. code-block::

    {
        "state_dict": ...,
        "callbacks": {
            "Counter[what=batches]": {"batches": 32, "epochs": 0},
            "Counter[what=epochs]": {"batches": 0, "epochs": 2},
            ...
        }
    }


Implementing :attr:`~lightning.pytorch.callbacks.Callback.state_key` is essential here. By default,
:attr:`~lightning.pytorch.callbacks.Callback.state_key` uses only the class name as the key, here ``Counter``,
so without the override both callbacks would share the same key and Lightning could not tell their states apart.
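
To see why the key matters, the standalone sketch below mimics a checkpoint's ``"callbacks"`` entry as a plain
dictionary keyed by ``state_key``. ``MiniCounter``, ``DefaultKeyCounter``, ``collect_states`` and ``make_pair``
are hypothetical stand-ins written for this illustration; they do not reproduce Lightning's actual saving logic.

.. code-block:: python

    class MiniCounter:
        """Stand-in for the ``Counter`` callback above (no Lightning dependency)."""

        def __init__(self, what="epochs"):
            self.what = what
            self.state = {"epochs": 0, "batches": 0}

        @property
        def state_key(self):
            return f"Counter[what={self.what}]"

        def state_dict(self):
            return self.state.copy()


    class DefaultKeyCounter(MiniCounter):
        """Same stand-in, but keyed by the class name only."""

        @property
        def state_key(self):
            return type(self).__qualname__


    def collect_states(callbacks):
        """Assumed saving scheme: one dictionary entry per ``state_key``."""
        return {cb.state_key: cb.state_dict() for cb in callbacks}


    def make_pair(cls):
        """Build an epoch counter and a batch counter with some recorded progress."""
        epochs, batches = cls(what="epochs"), cls(what="batches")
        epochs.state["epochs"] = 2
        batches.state["batches"] = 32
        return [epochs, batches]


    distinct = collect_states(make_pair(MiniCounter))
    collided = collect_states(make_pair(DefaultKeyCounter))

In this simplified scheme, ``distinct`` keeps one entry per callback, while ``collided`` holds a single entry
because the second callback's state overwrites the first and the epoch count is lost.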