The easiest way to use a model for predictions is to load the weights using **load_from_checkpoint** found in the LightningModule.
.. code-block:: python

    model = LitModel.load_from_checkpoint("best_model.ckpt")
    model.eval()
    x = torch.randn(1, 64)

    with torch.no_grad():
        y_hat = model(x)
----
**************************************
Predict step with your LightningModule
**************************************
Loading a checkpoint and predicting still leaves you with a lot of boilerplate around the predict epoch. The **predict step** in the LightningModule removes this boilerplate.
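For example, a bare-bones ``predict_step`` (the ``MyModel`` name here is just a placeholder) can simply run a forward pass on each batch:

.. code-block:: python

    from lightning.pytorch import LightningModule


    class MyModel(LightningModule):
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            # by default, just run a forward pass on the batch
            return self(batch)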
And pass in any dataloader to the Lightning Trainer:
.. code-block:: python

    data_loader = DataLoader(...)
    model = MyModel()
    trainer = Trainer()
    predictions = trainer.predict(model, data_loader)
----
********************************
Enable complicated predict logic
********************************
When you need to add complicated pre-processing or post-processing logic to your data, use the ``predict_step``. For example, here we use `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_ for predictions:
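As a minimal sketch (the class name, ``mc_iteration`` count, and dropout placement are illustrative, assuming the wrapped ``model`` returns a tensor per batch):

.. code-block:: python

    import torch
    from torch import nn
    from lightning.pytorch import LightningModule


    class LitMCDropoutModel(LightningModule):
        def __init__(self, model, mc_iteration=50):
            super().__init__()
            self.model = model
            self.dropout = nn.Dropout(p=0.5)
            self.mc_iteration = mc_iteration

        def predict_step(self, batch, batch_idx):
            # keep dropout active at inference time (Monte Carlo Dropout)
            self.dropout.train()

            # average `mc_iteration` stochastic forward passes per batch
            preds = torch.stack([self.dropout(self.model(batch)) for _ in range(self.mc_iteration)])
            return preds.mean(dim=0)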
By using the predict step in Lightning, you get distributed inference for free using :class:`~lightning.pytorch.callbacks.prediction_writer.BasePredictionWriter`.
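As a rough sketch (the callback name and output directory are illustrative), a prediction writer callback can save each process's predictions to disk at the end of the predict epoch:

.. code-block:: python

    import os

    import torch
    from lightning.pytorch.callbacks import BasePredictionWriter


    class PredictionWriter(BasePredictionWriter):
        def __init__(self, output_dir, write_interval="epoch"):
            super().__init__(write_interval)
            self.output_dir = output_dir

        def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
            # each process writes its own shard of the predictions
            os.makedirs(self.output_dir, exist_ok=True)
            torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))

Pass an instance of the callback to the Trainer via ``callbacks=[...]`` and call ``trainer.predict(...)`` as usual.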