Lightning automates the training loop for you and manages all of the associated components, such as epoch and batch tracking, optimizers and schedulers,
and metric reduction. To define how your model behaves on a batch of training data, override the
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method, which takes the current ``batch`` and the ``batch_idx``
as arguments. Optionally, it can take ``optimizer_idx`` if your LightningModule defines multiple optimizers within its
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` hook.
--------

*****************************
4. Configure Validation Logic
*****************************

Lightning also automates the validation loop for you and manages all of the associated components, such as epoch and batch tracking and metric reduction.
To define how your model behaves on a batch of validation data, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`
method, which takes the current ``batch`` and the ``batch_idx`` as arguments. Optionally, it can take ``dataloader_idx`` if you configure multiple dataloaders.
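
For example, a minimal ``validation_step`` in the same spirit (the metric here is illustrative; ``self.log`` takes care of reducing the metric across batches):

.. testcode::

    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            val_loss = F.cross_entropy(y_hat, y)
            self.log("val_loss", val_loss)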
--------

**************************
5. Configure Testing Logic
**************************

Lightning automates the testing loop for you and manages all of the associated components, such as epoch and batch tracking and metric reduction.
To define how your model behaves on a batch of test data, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step`
method, which takes the current ``batch`` and the ``batch_idx`` as arguments. Optionally, it can take ``dataloader_idx`` if you configure multiple dataloaders.
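
For example, a minimal ``test_step`` might mirror the validation logic above (names are illustrative):

.. testcode::

    class LitModel(LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            test_loss = F.cross_entropy(y_hat, y)
            self.log("test_loss", test_loss)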
The test loop isn't used within :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`, so you need to call :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` explicitly.
.. code-block:: python

    model = LitModel()
    trainer = Trainer()
    trainer.test(model)
.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

.. tip:: ``trainer.test()`` loads the best checkpoint automatically by default if checkpointing is enabled.
--------
*****************************
6. Configure Prediction Logic
*****************************
Lightning automates the prediction loop for you and manages all of the associated components, such as epoch and batch tracking.
To define how your model behaves on a batch of data, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step`
method, which takes the current ``batch`` and the ``batch_idx`` as arguments. Optionally, it can take ``dataloader_idx`` if you configure multiple dataloaders.
If you don't override the ``predict_step`` hook, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` on the batch by default.
.. testcode::

    class LitModel(LightningModule):
        def predict_step(self, batch, batch_idx):
            x, y = batch
            pred = self(x)
            return pred
The predict loop will not be used until you call :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
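
For example (reusing the ``LitModel`` defined above):

.. code-block:: python

    model = LitModel()
    trainer = Trainer()
    predictions = trainer.predict(model)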
.. tip:: ``trainer.predict()`` loads the best checkpoint automatically by default if checkpointing is enabled.
--------
******************************************
7. Remove any .cuda() or .to(device) Calls
******************************************
Your :doc:`LightningModule <../common/lightning_module>` can automatically run on any hardware!
If you have any explicit calls to ``.cuda()`` or ``.to(device)``, you can remove them since Lightning makes sure that the data coming from :class:`~torch.utils.data.DataLoader`
and all the :class:`~torch.nn.Module` instances initialized inside ``LightningModule.__init__`` are moved to the respective devices automatically.
If you still need to access the current device, you can use ``self.device`` anywhere in the ``LightningModule`` except the ``__init__`` method. If you are initializing a
:class:`~torch.Tensor` within the ``LightningModule.__init__`` method and want it to be moved to the device automatically, you must register it as a buffer with :meth:`~torch.nn.Module.register_buffer`.
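
For example, a minimal sketch (the buffer name ``sigma`` and its value are illustrative):

.. testcode::

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            # registered buffers are moved to the correct device along with the module
            self.register_buffer("sigma", torch.eye(3))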