.. testsetup:: *

    from pytorch_lightning.trainer.trainer import Trainer
    from pytorch_lightning.callbacks.base import Callback

.. role:: hidden
    :class: hidden-section

.. _callbacks:

Callback
========

.. raw:: html

    <video width="100%" max-width="400px" controls
    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/callbacks.jpg"
    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/callbacks.mp4"></video>

|

A callback is a self-contained program that can be reused across projects.

Lightning has a callback system that executes callbacks when needed. Callbacks should capture NON-ESSENTIAL
logic that is NOT required for your :doc:`lightning module <../common/lightning_module>` to run.

Here's the flow of how the callback hooks are executed:

.. raw:: html

    <video width="100%" max-width="400px" controls autoplay muted playsinline src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pt_callbacks_mov.m4v"></video>

An overall Lightning system should have:

1. Trainer for all engineering.
2. LightningModule for all research code.
3. Callbacks for non-essential code.

|

Example:

.. testcode::

    from pytorch_lightning.callbacks import Callback


    class MyPrintingCallback(Callback):
        # on_init_start / on_init_end run while the Trainer is being constructed
        def on_init_start(self, trainer):
            print("Starting to init trainer!")

        def on_init_end(self, trainer):
            print("trainer is init now")

        # on_train_end only runs once trainer.fit(...) has finished training
        def on_train_end(self, trainer, pl_module):
            print("do something when training ends")


    # constructing the Trainer triggers only the two init hooks
    trainer = Trainer(callbacks=[MyPrintingCallback()])

.. testoutput::

    Starting to init trainer!
    trainer is init now

We successfully extended functionality without polluting our super clean
:doc:`lightning module <../common/lightning_module>` research code.
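
Conceptually, the mechanism behind the example above is simple: the Trainer keeps the list of callbacks you pass in and, at each point in its loop, calls the matching hook on every callback. The framework-free sketch below illustrates that dispatch pattern. ``ToyCallback``, ``ToyTrainer`` and ``EpochLogger`` are hypothetical names written for this illustration, not Lightning classes, and the real Trainer calls many more hooks with richer arguments.

.. code-block:: python

    class ToyCallback:
        """Hypothetical base class: every hook is a no-op unless overridden."""

        def on_train_start(self, trainer):
            pass

        def on_train_epoch_end(self, trainer):
            pass

        def on_train_end(self, trainer):
            pass


    class ToyTrainer:
        """Hypothetical trainer that only knows how to dispatch hooks."""

        def __init__(self, callbacks=None, max_epochs=2):
            self.callbacks = list(callbacks or [])
            self.max_epochs = max_epochs
            self.current_epoch = 0

        def _call_hook(self, name):
            # Call the named hook on every registered callback, in list order.
            for callback in self.callbacks:
                getattr(callback, name)(self)

        def fit(self):
            self._call_hook("on_train_start")
            for epoch in range(self.max_epochs):
                self.current_epoch = epoch
                # ... the actual training step would run here ...
                self._call_hook("on_train_epoch_end")
            self._call_hook("on_train_end")


    class EpochLogger(ToyCallback):
        """Overrides only the hooks it cares about."""

        def __init__(self):
            self.events = []

        def on_train_epoch_end(self, trainer):
            self.events.append(f"epoch {trainer.current_epoch} done")

        def on_train_end(self, trainer):
            self.events.append("training finished")


    logger = EpochLogger()
    ToyTrainer(callbacks=[logger], max_epochs=2).fit()
    print(logger.events)  # ['epoch 0 done', 'epoch 1 done', 'training finished']

Because the trainer only talks to the callback through hooks, the logger can be added or removed without touching the training loop itself.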

-----------

Examples
--------
You can do pretty much anything with callbacks.

- `Add an MLP to fine-tune self-supervised networks <https://lightning-bolts.readthedocs.io/en/latest/self_supervised_callbacks.html#sslonlineevaluator>`_.
- `Find out how to modify an image input to trick the classification result <https://lightning-bolts.readthedocs.io/en/latest/vision_callbacks.html#confused-logit>`_.
- `Interpolate the latent space of any variational model <https://lightning-bolts.readthedocs.io/en/latest/variational_callbacks.html#latent-dim-interpolator>`_.
- `Log images to TensorBoard for any model <https://lightning-bolts.readthedocs.io/en/latest/vision_callbacks.html#tensorboard-image-generator>`_.

--------------

Built-in Callbacks
------------------
Lightning has a few built-in callbacks.

.. note::
    For a richer collection of callbacks, check out our
    `bolts library <https://lightning-bolts.readthedocs.io/en/latest/callbacks.html>`_.

.. currentmodule:: pytorch_lightning.callbacks

.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    BackboneFinetuning
    BaseFinetuning
    BasePredictionWriter
    Callback
    DeviceStatsMonitor
    EarlyStopping
    GPUStatsMonitor
    GradientAccumulationScheduler
    LambdaCallback
    LearningRateMonitor
    ModelCheckpoint
    ModelPruning
    ModelSummary
    ProgressBar
    ProgressBarBase
    RichModelSummary
    RichProgressBar
    QuantizationAwareTraining
    StochasticWeightAveraging
    XLAStatsMonitor

----------

.. _Persisting Callback State:

Persisting State
----------------
Some callbacks require internal state in order to function properly. You can optionally
choose to persist your callback's state as part of model checkpoint files using the callback hooks
:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
Note that the state returned by ``on_save_checkpoint`` must be picklable.
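
Checkpoints are written with ``torch.save``, which relies on Python's ``pickle``, so a quick sanity check is to round-trip your callback state through ``pickle`` yourself. The framework-free sketch below assumes a state made of plain counters, like the ``Counter`` callback in this section; ``check_picklable`` is a hypothetical helper, not part of Lightning.

.. code-block:: python

    import pickle


    def check_picklable(state):
        """Hypothetical helper: True if `state` pickles and compares equal after restoring."""
        try:
            restored = pickle.loads(pickle.dumps(state))
        except (pickle.PicklingError, AttributeError, TypeError):
            # Lambdas, open file handles, locks, etc. typically fail here.
            return False
        return restored == state


    # Plain containers of numbers and strings pickle fine.
    good_state = {"epochs": 2, "batches": 32}
    print(check_picklable(good_state))  # True

    # An inline lambda cannot be pickled by reference.
    bad_state = {"epochs": 2, "scheduler": lambda epoch: 0.1 * epoch}
    print(check_picklable(bad_state))  # False

Keeping the state to plain dictionaries of numbers and strings, as the ``Counter`` example does, avoids this class of problem entirely.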

When your callback is meant to be used only as a singleton callback, implementing the above two hooks is enough
to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then
the callback must define a :attr:`~pytorch_lightning.callbacks.Callback.state_key` property in order for Lightning
to be able to distinguish the different states when loading the callback state. This concept is best illustrated by
the following example.

.. testcode::

    class Counter(Callback):
        def __init__(self, what="epochs", verbose=True):
            self.what = what
            self.verbose = verbose
            self.state = {"epochs": 0, "batches": 0}

        @property
        def state_key(self):
            # note: `verbose` is left out on purpose, it does not change what is counted
            return self._generate_state_key(what=self.what)

        def on_train_epoch_end(self, *args, **kwargs):
            if self.what == "epochs":
                self.state["epochs"] += 1

        def on_train_batch_end(self, *args, **kwargs):
            if self.what == "batches":
                self.state["batches"] += 1

        def on_load_checkpoint(self, trainer, pl_module, callback_state):
            # restore the counts that were saved for this callback
            self.state.update(callback_state)

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # the returned dict is stored in the checkpoint for this callback
            return self.state.copy()


    # two callbacks of the same type are being used
    trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])

A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:

.. code-block:: python

    {
        "state_dict": ...,
        "callbacks": {
            "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0},
            "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2},
            ...
        }
    }

Implementing :attr:`~pytorch_lightning.callbacks.Callback.state_key` is essential here. By default,
:attr:`~pytorch_lightning.callbacks.Callback.state_key` only uses the class name as the key (here, ``Counter``), so without
the override both callbacks would share one key and Lightning would not be able to disambiguate their states.
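
To see why the key matters, here is a framework-free sketch of how per-callback states could be gathered into one checkpoint dictionary keyed by ``state_key``. The key format mirrors the one in the checkpoint shown above (class name followed by the ``repr`` of the keyword arguments). ``ToyCallback``, ``NaiveCounter``, ``KeyedCounter`` and ``collect_callback_states`` are hypothetical names for this illustration, not Lightning's implementation.

.. code-block:: python

    class ToyCallback:
        """Hypothetical base class mimicking the default state_key behaviour."""

        @property
        def state_key(self):
            # Default: only the class name, so two instances of one class share a key.
            return type(self).__qualname__

        def _generate_state_key(self, **kwargs):
            # Class name followed by repr(kwargs), like the checkpoint keys shown above.
            return f"{type(self).__qualname__}{kwargs!r}"


    class NaiveCounter(ToyCallback):
        def __init__(self, what, count):
            self.what = what
            self.state = {what: count}


    class KeyedCounter(NaiveCounter):
        @property
        def state_key(self):
            return self._generate_state_key(what=self.what)


    def collect_callback_states(callbacks):
        """Gather every callback's state into one dict keyed by state_key."""
        states = {}
        for callback in callbacks:
            # A repeated key silently overwrites the earlier callback's state.
            states[callback.state_key] = dict(callback.state)
        return states


    naive = collect_callback_states([NaiveCounter("epochs", 2), NaiveCounter("batches", 32)])
    print(naive)  # {'NaiveCounter': {'batches': 32}}  <- the epoch count was lost

    keyed = collect_callback_states([KeyedCounter("epochs", 2), KeyedCounter("batches", 32)])
    print(keyed)  # {"KeyedCounter{'what': 'epochs'}": {'epochs': 2}, "KeyedCounter{'what': 'batches'}": {'batches': 32}}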

Best Practices
--------------
The following are best practices when using/designing callbacks.

1. Callbacks should be isolated in their functionality.
2. Your callback should not rely on the behavior of other callbacks in order to work properly.
3. Do not manually call methods from the callback. Directly calling hooks (e.g. ``on_validation_end``) is strongly discouraged.
4. Whenever possible, your callbacks should not depend on the order in which they are executed.
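
The last point is easiest to see with a small example. In the framework-free sketch below (all names are hypothetical, and ``run`` stands in for a trainer loop), one callback reads a value that another callback writes into shared state, so its result changes when the two are reordered. A self-contained callback that computes what it needs from its own inputs gives the same answer in either order.

.. code-block:: python

    class ScaleLoss:
        """Writes a derived value into shared state (the pattern to avoid)."""

        def on_batch_end(self, metrics, loss):
            metrics["scaled_loss"] = loss * 10


    class ReadScaledLoss:
        """Reads a value another callback wrote, so it depends on execution order."""

        def __init__(self):
            self.seen = []

        def on_batch_end(self, metrics, loss):
            self.seen.append(metrics.get("scaled_loss"))


    class SelfContainedScaledLoss:
        """Computes what it needs from its own inputs, so order does not matter."""

        def __init__(self):
            self.seen = []

        def on_batch_end(self, metrics, loss):
            self.seen.append(loss * 10)


    def run(callbacks, losses):
        """Hypothetical loop: call on_batch_end on each callback, in list order."""
        metrics = {}
        for loss in losses:
            for callback in callbacks:
                callback.on_batch_end(metrics, loss)


    losses = [0.5, 0.25]

    after, before = ReadScaledLoss(), ReadScaledLoss()
    run([ScaleLoss(), after], losses)
    run([before, ScaleLoss()], losses)
    print(after.seen)   # [5.0, 2.5]
    print(before.seen)  # [None, 5.0]  <- missing, then stale

    first, second = SelfContainedScaledLoss(), SelfContainedScaledLoss()
    run([first, ScaleLoss()], losses)
    run([ScaleLoss(), second], losses)
    print(first.seen == second.seen)  # True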

-----------

.. _hooks:

Available Callback Hooks
------------------------

setup
^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.setup
    :noindex:

teardown
^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.teardown
    :noindex:

on_init_start
^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_init_start
    :noindex:

on_init_end
^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_init_end
    :noindex:

on_fit_start
^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_fit_start
    :noindex:

on_fit_end
^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_fit_end
    :noindex:

on_sanity_check_start
^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_sanity_check_start
    :noindex:

on_sanity_check_end
^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_sanity_check_end
    :noindex:

on_train_batch_start
^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_train_batch_start
    :noindex:

on_train_batch_end
^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_train_batch_end
    :noindex:

on_train_epoch_start
^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_train_epoch_start
    :noindex:

on_train_epoch_end
^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_train_epoch_end
    :noindex:

on_validation_epoch_start
^^^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_epoch_start
    :noindex:

on_validation_epoch_end
^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_epoch_end
    :noindex:

on_test_epoch_start
^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_test_epoch_start
    :noindex:

on_test_epoch_end
^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_test_epoch_end
    :noindex:

on_epoch_start
^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_start
    :noindex:

on_epoch_end
^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_end
    :noindex:

on_batch_start
^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_start
    :noindex:

on_validation_batch_start
^^^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_batch_start
    :noindex:

on_validation_batch_end
^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_batch_end
    :noindex:

on_test_batch_start
^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_test_batch_start
    :noindex:

on_test_batch_end
^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_test_batch_end
    :noindex:

on_batch_end
^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_end
    :noindex:

on_train_start
^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_train_start
    :noindex:

on_train_end
^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_train_end
    :noindex:

on_pretrain_routine_start
^^^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_start
    :noindex:

on_pretrain_routine_end
^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_end
    :noindex:

on_validation_start
^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_start
    :noindex:

on_validation_end
^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_end
    :noindex:

on_test_start
^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_test_start
    :noindex:

on_test_end
^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_test_end
    :noindex:

on_keyboard_interrupt
^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_keyboard_interrupt
    :noindex:

on_exception
^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_exception
    :noindex:
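
As an illustration of how an exception hook can be used, the framework-free sketch below shows a trainer loop that notifies every callback when an error escapes and then re-raises it, so a callback can record the failure or release resources. ``ToyTrainer`` and ``FailureRecorder`` are hypothetical; the toy hook takes ``(trainer, exception)`` for brevity, so check the reference above for the real signature, and Lightning's own error handling does more than this.

.. code-block:: python

    class FailureRecorder:
        """Hypothetical callback that records any exception raised during fit."""

        def __init__(self):
            self.errors = []

        def on_exception(self, trainer, exception):
            self.errors.append(repr(exception))


    class ToyTrainer:
        """Hypothetical trainer: run the steps, notify callbacks on failure, re-raise."""

        def __init__(self, callbacks):
            self.callbacks = callbacks

        def fit(self, steps):
            try:
                for step in steps:
                    step()
            except BaseException as exception:  # also catches KeyboardInterrupt
                for callback in self.callbacks:
                    callback.on_exception(self, exception)
                raise


    def bad_step():
        raise ValueError("loss is NaN")


    recorder = FailureRecorder()
    try:
        ToyTrainer([recorder]).fit([lambda: None, bad_step])
    except ValueError:
        pass  # the error still propagates to the caller after the callbacks ran
    print(recorder.errors)  # ["ValueError('loss is NaN')"]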

on_save_checkpoint
^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint
    :noindex:

on_load_checkpoint
^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
    :noindex:

on_before_backward
^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_backward
    :noindex:

on_after_backward
^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
    :noindex:

on_before_optimizer_step
^^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_optimizer_step
    :noindex:

on_before_zero_grad
^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad
    :noindex: