.. _hooks:

Model Hooks
===========

Sometimes you need to run custom code at a specific point in the training, validation, or test loop. To do so, override the corresponding hook method in your LightningModule; the Trainer calls it at the appropriate time.

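The mechanism can be pictured with a plain-Python sketch. ``ToyHooks``, ``MyModule`` and ``run_toy_batch`` are hypothetical names, not Lightning classes: the base class defines each hook as a no-op, the loop calls every hook unconditionally, and a subclass changes behavior only for the hooks it overrides.

.. code-block:: python

    class ToyHooks:
        """Hypothetical hooks base class: every hook is a no-op by default."""

        def on_batch_start(self, batch):
            pass

        def on_batch_end(self):
            pass


    class MyModule(ToyHooks):
        """Overrides on_batch_start only; on_batch_end keeps the no-op default."""

        def __init__(self):
            self.seen = []

        def on_batch_start(self, batch):
            self.seen.append(batch)


    def run_toy_batch(model, batch):
        """Hypothetical loop fragment that calls each hook unconditionally."""
        model.on_batch_start(batch)
        # ... the training step would run here ...
        model.on_batch_end()


    model = MyModule()
    for batch in [1, 2, 3]:
        run_toy_batch(model, batch)
    print(model.seen)  # [1, 2, 3]

Because the loop never checks whether a hook was overridden, any hook a model does not implement simply falls through to the no-op default.
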

**Contributing:** If there's a hook you'd like to add:

1. Fork `PyTorchLightning <https://github.com/PyTorchLightning/pytorch-lightning>`_.
2. Add the hook to :class:`pytorch_lightning.core.hooks.ModelHooks`.
3. Call it from the correct place in :mod:`pytorch_lightning.trainer`.

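Steps 2 and 3 can be sketched, and checked, with stand-in classes. ``ToyModelHooks``, ``ToyTrainer`` and the hook ``on_new_hook`` are hypothetical; the real :class:`~pytorch_lightning.core.hooks.ModelHooks` and trainer are much larger. A recording subclass confirms that the new hook fires at the intended point, and the no-op default keeps models that never override it working.

.. code-block:: python

    class ToyModelHooks:
        """Hypothetical stand-in for the hooks class (step 2)."""

        def on_epoch_start(self):
            pass

        def on_epoch_end(self):
            pass

        # Step 2: the new hook gets a no-op default, so existing models are unaffected.
        def on_new_hook(self):
            pass


    class ToyTrainer:
        """Hypothetical stand-in for the trainer (step 3)."""

        def fit(self, model, max_epochs=2):
            for _ in range(max_epochs):
                model.on_epoch_start()
                # ... batches would run here ...
                # Step 3: call the new hook where it should fire.
                model.on_new_hook()
                model.on_epoch_end()


    class RecordingModel(ToyModelHooks):
        """Records every hook call so the call site can be checked."""

        def __init__(self):
            self.calls = []

        def on_epoch_start(self):
            self.calls.append("on_epoch_start")

        def on_new_hook(self):
            self.calls.append("on_new_hook")

        def on_epoch_end(self):
            self.calls.append("on_epoch_end")


    model = RecordingModel()
    ToyTrainer().fit(model, max_epochs=2)
    print(model.calls)
    # ['on_epoch_start', 'on_new_hook', 'on_epoch_end', 'on_epoch_start', 'on_new_hook', 'on_epoch_end']

    # A model that never overrides the new hook still runs unchanged.
    ToyTrainer().fit(ToyModelHooks())
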

----------------

Hooks lifecycle
---------------

Training set-up
^^^^^^^^^^^^^^^

- :meth:`~pytorch_lightning.core.lightning.LightningModule.prepare_data`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.setup`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.init_ddp_connection`
- :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.train_dataloader`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_dataloader`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.val_dataloader`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.summarize`
- :meth:`~pytorch_lightning.trainer.training_io.TrainerIOMixin.restore_weights`

.. warning:: ``prepare_data`` is called only on the process with ``global_rank=0``. Do not assign state in it (e.g. ``self.something = ...``); use ``setup`` for that, since ``setup`` runs on every process.

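The reason for this warning can be reproduced with a single-process simulation of several ranks. This is a hypothetical model rather than Lightning code (``ToyModule`` and ``simulate_ranks`` are illustrative names, and argument lists are simplified): each simulated process holds its own copy of the module, ``prepare_data`` runs only where ``global_rank == 0``, and ``setup`` runs on every rank.

.. code-block:: python

    class ToyModule:
        """Hypothetical module; each simulated process gets its own instance."""

        def prepare_data(self):
            # Anti-pattern: this attribute exists only on the rank that ran prepare_data.
            self.vocab = ["a", "b", "c"]

        def setup(self):
            # Per-process state belongs here: setup runs on every rank.
            self.train_size = 100


    def simulate_ranks(world_size):
        """Apply the documented call pattern to one module copy per rank."""
        modules = []
        for global_rank in range(world_size):
            module = ToyModule()  # separate processes do not share Python objects
            if global_rank == 0:
                module.prepare_data()
            module.setup()
            modules.append(module)
        return modules


    modules = simulate_ranks(world_size=2)
    print([hasattr(m, "vocab") for m in modules])       # [True, False]
    print([hasattr(m, "train_size") for m in modules])  # [True, True]

In practice ``prepare_data`` is meant for work whose result lands on disk, such as downloading a dataset, so that every rank can then load it in ``setup``.
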

----------

Training loop
^^^^^^^^^^^^^

- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_epoch_start`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_batch_start`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end` (optional)
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.backward`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_after_backward`
- ``optimizer.step()``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_batch_end`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_epoch_end`

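The order of the training-loop hooks can be modelled by a loop that records each call. This is a simplified, hypothetical sketch (``RecordingModule`` and ``fit_one_epoch`` are illustrative names): it follows the list order for a single optimizer, omits ``tbptt_split_batch`` and ``training_step_end``, and records ``optimizer.step()`` under its own name even though it is not a hook.

.. code-block:: python

    class RecordingModule:
        """Hypothetical module that records the name of every hook it receives."""

        def __init__(self):
            self.calls = []

        def hook(self, name):
            # Generic stand-in for hooks whose body does not matter here.
            self.calls.append(name)

        def training_step(self, batch, batch_idx):
            self.calls.append("training_step")
            return {"loss": float(batch)}  # toy "loss"

        def training_epoch_end(self, outputs):
            self.calls.append("training_epoch_end")
            return {"avg_loss": sum(o["loss"] for o in outputs) / len(outputs)}


    def fit_one_epoch(model, batches):
        """Simplified epoch: one optimizer, no tbptt_split_batch, no training_step_end."""
        outputs = []
        model.hook("on_epoch_start")
        for batch_idx, batch in enumerate(batches):
            model.hook("on_batch_start")
            outputs.append(model.training_step(batch, batch_idx))
            model.hook("on_before_zero_grad")
            model.hook("backward")
            model.hook("on_after_backward")
            model.hook("optimizer.step")  # not a hook; recorded to show its position
            model.hook("on_batch_end")
        result = model.training_epoch_end(outputs)
        model.hook("on_epoch_end")
        return result


    model = RecordingModule()
    result = fit_one_epoch(model, [1, 2])
    print(result)  # {'avg_loss': 1.5}

The per-batch ``training_step`` outputs are collected and passed to ``training_epoch_end`` once, after the last batch.
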

----------

Validation loop
^^^^^^^^^^^^^^^

- ``model.zero_grad()``
- ``model.eval()``
- ``torch.set_grad_enabled(False)``
- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step_end` (optional)
- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
- ``model.train()``
- ``torch.set_grad_enabled(True)``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check`

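The state changes around ``validation_step`` can be sketched without PyTorch. ``ToyModel``, ``GRAD_ENABLED``, ``set_grad_enabled`` and ``run_validation`` are hypothetical stand-ins for ``model.eval()``/``model.train()`` and ``torch.set_grad_enabled``; ``validation_step_end`` and ``on_post_performance_check`` are left out. The point is that every validation step runs in eval mode with gradients disabled, and both settings are restored before training resumes.

.. code-block:: python

    GRAD_ENABLED = True  # hypothetical stand-in for torch's global grad mode


    def set_grad_enabled(mode):
        """Stand-in for torch.set_grad_enabled(mode)."""
        global GRAD_ENABLED
        GRAD_ENABLED = mode


    class ToyModel:
        """Hypothetical model that records the state each validation step sees."""

        def __init__(self):
            self.training = True
            self.log = []

        def zero_grad(self):
            self.log.append("zero_grad")

        def eval(self):
            self.training = False

        def train(self):
            self.training = True

        def validation_step(self, batch, batch_idx):
            # Record (hook name, training mode, grad mode) as seen inside the step.
            self.log.append(("validation_step", self.training, GRAD_ENABLED))
            return {"val_loss": float(batch)}

        def validation_epoch_end(self, outputs):
            # Receives every validation_step output from the epoch.
            return {"val_loss": sum(o["val_loss"] for o in outputs) / len(outputs)}


    def run_validation(model, batches):
        """Simplified validation loop following the order listed above."""
        model.zero_grad()
        model.eval()
        set_grad_enabled(False)
        outputs = [model.validation_step(b, i) for i, b in enumerate(batches)]
        result = model.validation_epoch_end(outputs)
        model.train()
        set_grad_enabled(True)
        return result


    model = ToyModel()
    result = run_validation(model, [1.0, 3.0])
    print(result)                        # {'val_loss': 2.0}
    print(model.training, GRAD_ENABLED)  # True True
    print(model.log[1])                  # ('validation_step', False, False)
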

----------

Test loop
^^^^^^^^^

- ``model.zero_grad()``
- ``model.eval()``
- ``torch.set_grad_enabled(False)``
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step_end` (optional)
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`
- ``model.train()``
- ``torch.set_grad_enabled(True)``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check`

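How outputs flow from ``test_step`` through ``test_step_end`` to ``test_epoch_end`` can be sketched as follows. ``test_step_end`` matters mostly with ``dp``/``ddp2``, where each device's ``test_step`` sees only part of a batch. ``ToyTestModule``, ``run_test`` and ``num_devices`` are hypothetical, and the per-device results are passed as a plain list for simplicity; in Lightning the backend gathers them before ``test_step_end`` is called.

.. code-block:: python

    class ToyTestModule:
        """Hypothetical module showing how test outputs flow between hooks."""

        def test_step(self, batch, batch_idx):
            # Each simulated device scores only its slice of (prediction, target) pairs.
            correct = sum(1 for pred, target in batch if pred == target)
            return {"correct": correct, "total": len(batch)}

        def test_step_end(self, part_outputs):
            # Combine the per-device results for one full batch.
            return {
                "correct": sum(o["correct"] for o in part_outputs),
                "total": sum(o["total"] for o in part_outputs),
            }

        def test_epoch_end(self, outputs):
            # Receives one combined entry per batch.
            correct = sum(o["correct"] for o in outputs)
            total = sum(o["total"] for o in outputs)
            return {"test_acc": correct / total}


    def run_test(module, batches, num_devices=2):
        """Simulate splitting each batch across devices, as dp/ddp2 would."""
        outputs = []
        for batch_idx, batch in enumerate(batches):
            size = -(-len(batch) // num_devices)  # ceiling division: items per device
            parts = [batch[i:i + size] for i in range(0, len(batch), size)]
            part_outputs = [module.test_step(part, batch_idx) for part in parts]
            outputs.append(module.test_step_end(part_outputs))
        return module.test_epoch_end(outputs)


    batches = [
        [(1, 1), (0, 1), (1, 1), (0, 0)],
        [(1, 0), (0, 0), (1, 1), (1, 1)],
    ]
    result = run_test(ToyTestModule(), batches)
    print(result)  # {'test_acc': 0.75}
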

----------------

General hooks
-------------

.. automodule:: pytorch_lightning.core.hooks
   :noindex: