.. _production_inference:

#######################
Inference in Production
#######################

Once a model is trained, deploying to production and running inference is the next task. To help you with it, here are the possible approaches you can use to deploy and run inference with your models.

------------

******************
With Lightning API
******************

The following are some possible ways you can use Lightning to run inference in production. Note that PyTorch Lightning has some extra dependencies, so using raw PyTorch might be advantageous in your production environment.

------------

Prediction API
==============

Lightning provides you with a prediction API that can be accessed using :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` calls the :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method. To customize this behaviour, simply override the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method in your :class:`~pytorch_lightning.core.lightning.LightningModule`. This can be useful to add some pre-processing or post-processing logic to your data.

For the example, let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/abs/1506.02142>`_:

.. code-block:: python

    class LitMCdropoutModel(pl.LightningModule):
        def __init__(self, model, mc_iteration):
            super().__init__()
            self.model = model
            self.dropout = nn.Dropout()
            self.mc_iteration = mc_iteration

        def predict_step(self, batch, batch_idx):
            # enable Monte Carlo Dropout
            self.dropout.train()

            # take the average of `self.mc_iteration` iterations
            pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
            pred = torch.vstack(pred).mean(dim=0)
            return pred

------------

PyTorch Runtime
===============

You can also load the saved checkpoint and use it as a regular :class:`torch.nn.Module`.

.. code-block:: python

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()

    # train it
    trainer = Trainer(accelerator="gpu", devices=2)
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.save_checkpoint("best_model.ckpt", weights_only=True)

    # use the model after training or load the weights and drop into the production system
    model = SimpleModel.load_from_checkpoint("best_model.ckpt")
    model.eval()
    x = torch.randn(1, 64)

    with torch.no_grad():
        y_hat = model(x)
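The loaded model can also be handed back to the prediction API described above instead of calling ``forward`` manually. A minimal sketch, assuming the ``SimpleModel`` defined above; the predict dataloader here is hypothetical:

.. code-block:: python

    from torch.utils.data import DataLoader

    # hypothetical dataloader yielding batches of shape (8, 64)
    predict_loader = DataLoader(torch.randn(32, 64), batch_size=8)

    trainer = Trainer(accelerator="cpu", devices=1)
    # returns a list with one output tensor per batch
    predictions = trainer.predict(model, dataloaders=predict_loader)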
------------

*********************
Without Lightning API
*********************

As the :class:`~pytorch_lightning.core.lightning.LightningModule` is simply a :class:`torch.nn.Module`, common techniques for exporting PyTorch models to production apply here too. However, the :class:`~pytorch_lightning.core.lightning.LightningModule` provides helper methods to help you out with it.

------------

Convert to ONNX
===============

Lightning provides a handy function to quickly export your model to the `ONNX <https://onnx.ai/>`_ format, which allows the model to be independent of PyTorch and run on an ONNX Runtime.

To export your model to ONNX format, call the :meth:`~pytorch_lightning.core.lightning.LightningModule.to_onnx` function on your :class:`~pytorch_lightning.core.lightning.LightningModule` with the ``filepath`` and ``input_sample``.

.. code-block:: python

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()
    filepath = "model.onnx"
    input_sample = torch.randn((1, 64))
    model.to_onnx(filepath, input_sample, export_params=True)

You can also skip passing the input sample if the ``example_input_array`` property is specified in your :class:`~pytorch_lightning.core.lightning.LightningModule`.

.. code-block:: python

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            self.example_input_array = torch.randn(7, 64)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()
    filepath = "model.onnx"
    model.to_onnx(filepath, export_params=True)

Once you have the exported model, you can run it on your ONNX runtime in the following way:

.. code-block:: python

    import numpy as np
    import onnxruntime

    ort_session = onnxruntime.InferenceSession(filepath)
    input_name = ort_session.get_inputs()[0].name
    # the exported model expects float32 inputs
    ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
    ort_outs = ort_session.run(None, ort_inputs)

------------

Convert to TorchScript
======================

`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ allows you to serialize your models so that they can be loaded in non-Python environments. The ``LightningModule`` has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript` that returns a scripted module, which you can save or use directly.

.. testcode:: python

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()
    script = model.to_torchscript()

    # save for use in a production environment
    torch.jit.save(script, "model.pt")

It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.

Once you have the exported model, you can run it in a PyTorch or C++ runtime:

.. code-block:: python

    inp = torch.rand(1, 64)
    scripted_module = torch.jit.load("model.pt")
    output = scripted_module(inp)

If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:

.. code-block:: python

    class LitMCdropoutModel(pl.LightningModule):
        def __init__(self, model, mc_iteration):
            super().__init__()
            self.model = model
            self.dropout = nn.Dropout()
            self.mc_iteration = mc_iteration

        @torch.jit.export
        def predict_step(self, batch, batch_idx):
            # enable Monte Carlo Dropout
            self.dropout.train()

            # take the average of `self.mc_iteration` iterations
            pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
            pred = torch.vstack(pred).mean(dim=0)
            return pred


    model = LitMCdropoutModel(...)
    script = model.to_torchscript(file_path="model.pt", method="script")
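Besides scripting, :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript` can also trace the ``forward`` method when given an example input. A minimal sketch, assuming the ``SimpleModel`` defined above; the file name is illustrative:

.. code-block:: python

    # trace the forward pass instead of scripting it;
    # `example_inputs` must match the shape `forward` expects
    model = SimpleModel()
    traced = model.to_torchscript(
        file_path="traced_model.pt", method="trace", example_inputs=torch.randn(1, 64)
    )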
------------

PyTorch Runtime
===============

You can also load the saved checkpoint and use it as a regular :class:`torch.nn.Module`. You can extract all your :class:`torch.nn.Module` components and load the weights from the checkpoint saved by the LightningModule after training. For this, we recommend copying the exact implementation from your LightningModule's ``__init__`` and ``forward`` methods.

.. code-block:: python

    class Encoder(nn.Module):
        ...


    class Decoder(nn.Module):
        ...


    class AutoEncoderProd(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, x):
            return self.encoder(x)


    class AutoEncoderSystem(LightningModule):
        def __init__(self):
            super().__init__()
            self.auto_encoder = AutoEncoderProd()

        def forward(self, x):
            return self.auto_encoder.encoder(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.auto_encoder.encoder(x)
            y_hat = self.auto_encoder.decoder(y_hat)
            loss = ...
            return loss


    # train it
    trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
    model = AutoEncoderSystem()
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.save_checkpoint("best_model.ckpt")


    # create the PyTorch model and load the checkpoint weights
    model = AutoEncoderProd()
    checkpoint = torch.load("best_model.ckpt")
    hyper_parameters = checkpoint["hyper_parameters"]

    # if you want to restore any hyperparameters, you can pass them too
    model = AutoEncoderProd(**hyper_parameters)

    model_weights = checkpoint["state_dict"]

    # update keys by dropping `auto_encoder.`
    for key in list(model_weights):
        model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)

    model.load_state_dict(model_weights)
    model.eval()
    x = torch.randn(1, 64)

    with torch.no_grad():
        y_hat = model(x)
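If the production environment cannot run Python, the restored module can then be exported with the techniques shown earlier. A minimal sketch, assuming the same ``(1, 64)`` input shape used above; the output file name is illustrative:

.. code-block:: python

    # trace the restored production module for a non-Python runtime
    traced = torch.jit.trace(model, torch.randn(1, 64))
    torch.jit.save(traced, "prod_model.pt")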