Model Hooks
===========

There are cases when you might want to do something different at different parts of the training/validation loop.
To enable a hook, override the method in your LightningModule and the trainer will call it at the correct time.

**Contributing** If there's a hook you'd like to add:

1. Fork `PyTorchLightning <https://github.com/PyTorchLightning/pytorch-lightning>`_.

2. Add the hook to :class:`pytorch_lightning.core.hooks.ModelHooks`.

3. Add it in the correct place in :mod:`pytorch_lightning.trainer` where it should be called.

----------------

Hooks lifecycle
---------------

Training set-up
^^^^^^^^^^^^^^^

- :meth:`~pytorch_lightning.core.lightning.LightningModule.prepare_data`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.setup`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.init_ddp_connection`
- :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.train_dataloader`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_dataloader`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.val_dataloader`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.summarize`
- :meth:`~pytorch_lightning.trainer.training_io.TrainerIOMixin.restore_weights`

.. warning:: ``prepare_data`` is only called from ``global_rank=0``.
   Don't assign state in it (``self.something``); use ``setup`` for that.

----------

Training loop
^^^^^^^^^^^^^

- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_epoch_start`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_batch_start`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end` (optional)
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.backward`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_after_backward`
- ``optimizer.step()``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_batch_end`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end`
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_epoch_end`

----------

Validation loop
^^^^^^^^^^^^^^^

- ``model.zero_grad()``
- ``model.eval()``
- ``torch.set_grad_enabled(False)``
- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step_end`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
- ``model.train()``
- ``torch.set_grad_enabled(True)``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check`

----------

Test loop
^^^^^^^^^

- ``model.zero_grad()``
- ``model.eval()``
- ``torch.set_grad_enabled(False)``
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step_end`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`
- ``model.train()``
- ``torch.set_grad_enabled(True)``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check`

----------------

General hooks
-------------

.. automodule:: pytorch_lightning.core.hooks
   :noindex:
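
As a rough illustration of the override mechanism and of the training-loop order listed under "Hooks lifecycle", the following pure-Python sketch may help. It does not use or reproduce Lightning's trainer: ``ToyHooks``, ``ToyModule`` and ``run_toy_epoch`` are hypothetical names, the hook signatures are simplified assumptions, and only a subset of the listed hooks is modelled.

.. code-block:: python

    class ToyHooks:
        """Hypothetical stand-in for ModelHooks: every hook is a no-op by default."""

        def on_epoch_start(self):
            pass

        def on_batch_start(self, batch):
            pass

        def on_before_zero_grad(self):
            pass

        def on_after_backward(self):
            pass

        def on_batch_end(self):
            pass

        def on_epoch_end(self):
            pass


    class ToyModule(ToyHooks):
        """Hypothetical module that overrides two hooks and records what ran."""

        def __init__(self):
            self.calls = []

        def training_step(self, batch, batch_idx):
            self.calls.append("training_step")
            return sum(batch)  # toy "loss"

        # Overriding a hook is enough for the loop below to pick it up.
        def on_epoch_start(self):
            self.calls.append("on_epoch_start")

        def on_batch_end(self):
            self.calls.append("on_batch_end")


    def run_toy_epoch(module, batches):
        """Toy loop: calls a subset of the training-loop hooks in the listed order."""
        module.on_epoch_start()
        losses = []
        for batch_idx, batch in enumerate(batches):
            module.on_batch_start(batch)
            losses.append(module.training_step(batch, batch_idx))
            module.on_before_zero_grad()
            # a real loop would run the backward pass here
            module.on_after_backward()
            # a real loop would call optimizer.step() here
            module.on_batch_end()
        module.on_epoch_end()
        return losses


    module = ToyModule()
    losses = run_toy_epoch(module, [[1, 2], [3, 4]])
    print(module.calls)

Hooks that ``ToyModule`` does not override fall through to the no-op defaults, so only the overridden ones show up in ``module.calls``.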