lightning/docs/source/common/production_inference.rst

285 lines
9.2 KiB
ReStructuredText

.. _production_inference:
#######################
Inference in Production
#######################
Once a model is trained, deploying to production and running inference is the next task. To help you with it, here are the possible approaches
you can use to deploy and make inferences with your models.
------------
******************
With Lightning API
******************
The following are some possible ways you can use Lightning to run inference in production. Note that PyTorch Lightning has some extra dependencies and using raw PyTorch might be advantageous.
in your production environment.
------------
Prediction API
==============
Lightning provides you with a prediction API that can be accessed using :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
To configure this with your LightningModule, you would need to override the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method.
By default :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` calls the :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`
method. In order to customize this behaviour, simply override the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method. This can be useful to add some pre-processing or post-processing logic to your data.
For the example let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_:
.. code-block:: python
class LitMCdropoutModel(pl.LightningModule):
def __init__(self, model, mc_iteration):
super().__init__()
self.model = model
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration
def predict_step(self, batch, batch_idx):
# enable Monte Carlo Dropout
self.dropout.train()
# take average of `self.mc_iteration` iterations
pred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]
pred = torch.vstack(pred).mean(dim=0)
return pred
------------
PyTorch Runtime
===============
You can also load the saved checkpoint and use it as a regular :class:`torch.nn.Module`.
.. code-block:: python
class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
# create the model
model = SimpleModel()
# train it
trainer = Trainer(accelerator="gpu", devices=2)
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt", weights_only=True)
# use model after training or load weights and drop into the production system
model = SimpleModel.load_from_checkpoint("best_model.ckpt")
model.eval()
x = torch.randn(1, 64)
with torch.no_grad():
y_hat = model(x)
------------
*********************
Without Lightning API
*********************
As the :class:`~pytorch_lightning.core.lightning.LightningModule` is simply a :class:`torch.nn.Module`, common techniques to export PyTorch models
to production apply here too. However, the :class:`~pytorch_lightning.core.lightning.LightningModule` provides helper methods to help you out with it.
------------
Convert to ONNX
===============
Lightning provides a handy function to quickly export your model to `ONNX <https://pytorch.org/docs/stable/onnx.html>`_ format
which allows the model to be independent of PyTorch and run on an ONNX Runtime.
To export your model to ONNX format call the :meth:`~pytorch_lightning.core.lightning.LightningModule.to_onnx` function on your :class:`~pytorch_lightning.core.lightning.LightningModule` with the ``filepath`` and ``input_sample``.
.. code-block:: python
class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
# create the model
model = SimpleModel()
filepath = "model.onnx"
input_sample = torch.randn((1, 64))
model.to_onnx(filepath, input_sample, export_params=True)
You can also skip passing the input sample if the ``example_input_array`` property is specified in your :class:`~pytorch_lightning.core.lightning.LightningModule`.
.. code-block:: python
class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)
self.example_input_array = torch.randn(7, 64)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
# create the model
model = SimpleModel()
filepath = "model.onnx"
model.to_onnx(filepath, export_params=True)
Once you have the exported model, you can run it on your ONNX runtime in the following way:
.. code-block:: python
import onnxruntime
ort_session = onnxruntime.InferenceSession(filepath)
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(1, 64)}
ort_outs = ort_session.run(None, ort_inputs)
------------
Convert to TorchScript
======================
`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ allows you to serialize your models in a way that it can be loaded in non-Python environments.
The ``LightningModule`` has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript` that returns a scripted module which you
can save or directly use.
.. testcode:: python
class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
# create the model
model = SimpleModel()
script = model.to_torchscript()
# save for use in production environment
torch.jit.save(script, "model.pt")
It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
Once you have the exported model, you can run it in Pytorch or C++ runtime:
.. code-block:: python
inp = torch.rand(1, 64)
scripted_module = torch.jit.load("model.pt")
output = scripted_module(inp)
If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:
.. code-block:: python
class LitMCdropoutModel(pl.LightningModule):
def __init__(self, model, mc_iteration):
super().__init__()
self.model = model
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration
@torch.jit.export
def predict_step(self, batch, batch_idx):
# enable Monte Carlo Dropout
self.dropout.train()
# take average of `self.mc_iteration` iterations
pred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]
pred = torch.vstack(pred).mean(dim=0)
return pred
model = LitMCdropoutModel(...)
script = model.to_torchscript(file_path="model.pt", method="script")
------------
PyTorch Runtime
===============
You can also load the saved checkpoint and use it as a regular :class:`torch.nn.Module`. You can extract all your :class:`torch.nn.Module`
and load the weights using the checkpoint saved using LightningModule after training. For this, we recommend copying the exact implementation
from your LightningModule ``init`` and ``forward`` method.
.. code-block:: python
class Encoder(nn.Module):
...
class Decoder(nn.Module):
...
class AutoEncoderProd(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, x):
return self.encoder(x)
class AutoEncoderSystem(LightningModule):
def __init__(self):
super().__init__()
self.auto_encoder = AutoEncoderProd()
def forward(self, x):
return self.auto_encoder.encoder(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.auto_encoder.encoder(x)
y_hat = self.auto_encoder.decoder(y_hat)
loss = ...
return loss
# train it
trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
model = AutoEncoderSystem()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")
# create the PyTorch model and load the checkpoint weights
model = AutoEncoderProd()
checkpoint = torch.load("best_model.ckpt")
hyper_parameters = checkpoint["hyper_parameters"]
# if you want to restore any hyperparameters, you can pass them too
model = AutoEncoderProd(**hyper_parameters)
state_dict = checkpoint["state_dict"]
# update keys by dropping `auto_encoder.`
for key in list(model_weights):
model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)
model.load_state_dict(model_weights)
model.eval()
x = torch.randn(1, 64)
with torch.no_grad():
y_hat = model(x)