#####################################
Deploy models into production (basic)
#####################################
**Audience**: All users.

----

*****************************
Load a checkpoint and predict
*****************************
The simplest way to use a trained model for predictions is to load its weights with **load_from_checkpoint**, a class method of the LightningModule. Put the model in evaluation mode and disable gradient tracking before calling it.

.. code-block:: python

    import torch

    # LitModel is your own LightningModule subclass
    model = LitModel.load_from_checkpoint("best_model.ckpt")
    model.eval()  # disable dropout and use running batch-norm statistics
    x = torch.randn(1, 64)

    with torch.no_grad():  # no gradients are needed for inference
        y_hat = model(x)

----

**************************************
Predict step with your LightningModule
**************************************
Loading a checkpoint and calling the model directly still leaves you to write the boilerplate around the predict epoch yourself, such as iterating over the dataloader and collecting the outputs. The **predict step** in the LightningModule removes this boilerplate.

.. code-block:: python

    class MyModel(LightningModule):
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            # called once per batch by trainer.predict
            return self(batch)

Then pass any dataloader to the Lightning Trainer:

.. code-block:: python

    data_loader = DataLoader(...)
    model = MyModel()
    trainer = Trainer()
    # returns the collected outputs of predict_step, one entry per batch
    predictions = trainer.predict(model, data_loader)

----

********************************
Enable complicated predict logic
********************************
When you need complicated pre-processing or post-processing logic around your predictions, put it in the predict step. For example, here we use Monte Carlo Dropout for predictions:

.. code-block:: python

    import torch
    from torch import nn
    import pytorch_lightning as pl  # or: import lightning.pytorch as pl, depending on the installed package


    class LitMCdropoutModel(pl.LightningModule):
        def __init__(self, model, mc_iteration):
            super().__init__()
            self.model = model
            self.dropout = nn.Dropout()
            self.mc_iteration = mc_iteration

        def predict_step(self, batch, batch_idx):
            # enable Monte Carlo Dropout: keep dropout active at inference time
            self.dropout.train()

            # run `self.mc_iteration` stochastic forward passes and average them
            pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
            pred = torch.vstack(pred).mean(dim=0)
            return pred

----

****************************
Enable distributed inference
****************************
Because the predict step is run by the Trainer, distributed inference needs no extra code in your model. Set the number of devices and the accelerator on the Trainer:

.. code-block:: python

    trainer = Trainer(devices=8, accelerator="gpu")
    predictions = trainer.predict(model, data_loader)
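With a multi-process strategy, each process typically works on its own shard of the dataset, so the predictions returned in one process may cover only part of the data. The following is a minimal pure-Python sketch of how such shards could be formed and merged back into dataset order. It assumes round-robin sharding with wrap-around padding, similar to PyTorch's ``DistributedSampler`` with ``shuffle=False``; the actual behavior depends on the strategy and Lightning version. The helpers ``shard_indices`` and ``merge_predictions`` are hypothetical names for illustration and are not part of the Lightning API.

```python
import math


def shard_indices(num_samples, world_size, rank):
    """Sample indices one process would see (hypothetical helper).

    Assumes round-robin sharding; the index list is padded by wrapping
    around so that every rank receives the same number of samples.
    """
    total = math.ceil(num_samples / world_size) * world_size
    padded = [i % num_samples for i in range(total)]
    return padded[rank::world_size]


def merge_predictions(per_rank_preds, per_rank_indices, num_samples):
    """Reassemble per-rank predictions into dataset order (hypothetical helper)."""
    merged = [None] * num_samples
    for preds, indices in zip(per_rank_preds, per_rank_indices):
        for pred, idx in zip(preds, indices):
            # Padded duplicates simply overwrite an entry with a
            # prediction for the same sample.
            merged[idx] = pred
    return merged


# Toy stand-ins: 10 samples, 4 processes, and a "model" that doubles its input.
data = list(range(10))
world_size = 4
per_rank_indices = [shard_indices(len(data), world_size, r) for r in range(world_size)]
per_rank_preds = [[data[i] * 2 for i in idxs] for idxs in per_rank_indices]
merged = merge_predictions(per_rank_preds, per_rank_indices, len(data))
```

In a real run, each process would need to record which sample indices it handled (for example by returning them alongside the outputs of ``predict_step``) so that a merge like this is possible.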