.. _production_inference:

#######################
Inference in Production
#######################

Once a model is trained, deploying it to production and running inference is the next task. To help you with it, here are the possible approaches you can use to deploy and run inference with your models.

------------

******************
With Lightning API
******************

The following are some possible ways you can use Lightning to run inference in production. Note that PyTorch Lightning has some extra dependencies compared to raw PyTorch, so using raw PyTorch might be advantageous in your production environment.

------------

Prediction API
==============

Lightning provides a prediction API that can be accessed using :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` calls the :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method. To customize this behaviour, for example to add pre-processing or post-processing logic for your data, override :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` in your :class:`~pytorch_lightning.core.lightning.LightningModule`.

As an example, let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_:

.. code-block:: python

    class LitMCdropoutModel(pl.LightningModule):
        def __init__(self, model, mc_iteration):
            super().__init__()
            self.model = model
            self.dropout = nn.Dropout()
            self.mc_iteration = mc_iteration

        def predict_step(self, batch, batch_idx):
            # enable Monte Carlo Dropout
            self.dropout.train()

            # take average of `self.mc_iteration` iterations
            pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
            pred = torch.vstack(pred).mean(dim=0)
            return pred
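
To run prediction with a module like this, pass it to :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict` together with a dataloader. In the sketch below, the wrapped model and the data shapes are placeholder assumptions, not part of the example above:

.. code-block:: python

    from torch.utils.data import DataLoader

    # stand-in wrapped model and random data, for illustration only
    mc_model = LitMCdropoutModel(nn.Linear(64, 4), mc_iteration=10)
    dataloader = DataLoader(torch.randn(16, 64), batch_size=4)

    trainer = Trainer()
    predictions = trainer.predict(mc_model, dataloader)  # one averaged tensor per batch

Under the hood, ``predict_step`` is called once per batch and the results are collected into a list.
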
------------

PyTorch Runtime
===============

You can also load the saved checkpoint and use it as a regular :class:`torch.nn.Module`.

.. code-block:: python

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()

    # train it
    trainer = Trainer(accelerator="gpu", devices=2)
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.save_checkpoint("best_model.ckpt", weights_only=True)

    # use model after training or load weights and drop into the production system
    model = SimpleModel.load_from_checkpoint("best_model.ckpt")
    model.eval()
    x = torch.randn(1, 64)

    with torch.no_grad():
        y_hat = model(x)
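
If the checkpoint was trained on a GPU but your production machine is CPU-only, you can remap storages at load time. A minimal sketch using the ``map_location`` argument:

.. code-block:: python

    # load a GPU-trained checkpoint onto a CPU-only machine
    model = SimpleModel.load_from_checkpoint("best_model.ckpt", map_location="cpu")
    model.eval()
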
------------

*********************
Without Lightning API
*********************

Since the :class:`~pytorch_lightning.core.lightning.LightningModule` is simply a :class:`torch.nn.Module`, the common techniques for exporting PyTorch models to production apply here too. In addition, the :class:`~pytorch_lightning.core.lightning.LightningModule` provides helper methods that make the export easier.

------------

Convert to ONNX
===============

Lightning provides a handy function to quickly export your model to `ONNX <https://pytorch.org/docs/stable/onnx.html>`_ format, which allows the model to be independent of PyTorch and run on an ONNX Runtime.

To export your model to ONNX format, call the :meth:`~pytorch_lightning.core.lightning.LightningModule.to_onnx` method on your :class:`~pytorch_lightning.core.lightning.LightningModule` with a ``filepath`` and an ``input_sample``.

.. code-block:: python

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()
    filepath = "model.onnx"
    input_sample = torch.randn((1, 64))
    model.to_onnx(filepath, input_sample, export_params=True)
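
Assuming extra keyword arguments to ``to_onnx`` are forwarded to :func:`torch.onnx.export` (as in recent Lightning versions), you can, for example, mark the batch dimension as dynamic. A sketch with illustrative tensor and axis names:

.. code-block:: python

    model.to_onnx(
        filepath,
        input_sample,
        export_params=True,
        input_names=["input"],
        output_names=["output"],
        # allow variable batch sizes at inference time
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )
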
You can also skip passing the input sample if the ``example_input_array`` property is specified in your :class:`~pytorch_lightning.core.lightning.LightningModule`.

.. code-block:: python

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            self.example_input_array = torch.randn(7, 64)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()
    filepath = "model.onnx"
    model.to_onnx(filepath, export_params=True)

Once you have the exported model, you can run it with ONNX Runtime in the following way:

.. code-block:: python

    import numpy as np
    import onnxruntime

    ort_session = onnxruntime.InferenceSession(filepath)
    input_name = ort_session.get_inputs()[0].name
    # the exported model expects float32 inputs, so cast explicitly
    ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
    ort_outs = ort_session.run(None, ort_inputs)
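
To sanity-check the export, you can compare the ONNX Runtime output against the original PyTorch model on the same input. A small sketch (the tolerances are illustrative):

.. code-block:: python

    x = torch.randn(1, 64)
    model.eval()
    with torch.no_grad():
        torch_out = model(x)

    ort_out = ort_session.run(None, {input_name: x.numpy()})[0]
    np.testing.assert_allclose(torch_out.numpy(), ort_out, rtol=1e-03, atol=1e-05)
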
------------

Convert to TorchScript
======================

`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ allows you to serialize your models so that they can be loaded in non-Python environments. The ``LightningModule`` has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript` that returns a scripted module which you can save or use directly.

.. testcode:: python

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()
    script = model.to_torchscript()

    # save for use in production environment
    torch.jit.save(script, "model.pt")

It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.

Once you have the exported model, you can run it in a PyTorch or C++ runtime:

.. code-block:: python

    inp = torch.rand(1, 64)
    scripted_module = torch.jit.load("model.pt")
    output = scripted_module(inp)
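
By default, ``to_torchscript`` scripts the module; it also supports tracing by passing ``method="trace"`` together with example inputs (or by relying on ``example_input_array``). A minimal sketch:

.. code-block:: python

    model = SimpleModel()

    # trace instead of script; the example input shape matches SimpleModel's forward
    traced = model.to_torchscript(method="trace", example_inputs=torch.randn(1, 64))
    torch.jit.save(traced, "traced_model.pt")
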
If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:

.. code-block:: python

    class LitMCdropoutModel(pl.LightningModule):
        def __init__(self, model, mc_iteration):
            super().__init__()
            self.model = model
            self.dropout = nn.Dropout()
            self.mc_iteration = mc_iteration

        @torch.jit.export
        def predict_step(self, batch, batch_idx):
            # enable Monte Carlo Dropout
            self.dropout.train()

            # take average of `self.mc_iteration` iterations
            pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
            pred = torch.vstack(pred).mean(dim=0)
            return pred


    model = LitMCdropoutModel(...)
    script = model.to_torchscript(file_path="model.pt", method="script")
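
Methods decorated with :func:`torch.jit.export` remain callable on the loaded scripted module alongside ``forward``. A minimal sketch (the batch shape depends on the wrapped model and is a placeholder here):

.. code-block:: python

    scripted_module = torch.jit.load("model.pt")
    batch = torch.rand(4, 64)

    # call the exported method directly on the scripted module
    pred = scripted_module.predict_step(batch, 0)
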
------------

PyTorch Runtime
===============

You can also load the saved checkpoint and use it as a regular :class:`torch.nn.Module`. You can extract your inner :class:`torch.nn.Module` and load its weights from the checkpoint that the LightningModule saved after training. For this, we recommend copying the exact ``__init__`` and ``forward`` implementations between your production module and your LightningModule.

.. code-block:: python

    class Encoder(nn.Module):
        ...


    class Decoder(nn.Module):
        ...


    class AutoEncoderProd(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, x):
            return self.encoder(x)


    class AutoEncoderSystem(LightningModule):
        def __init__(self):
            super().__init__()
            self.auto_encoder = AutoEncoderProd()

        def forward(self, x):
            return self.auto_encoder.encoder(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.auto_encoder.encoder(x)
            y_hat = self.auto_encoder.decoder(y_hat)
            loss = ...
            return loss


    # train it
    trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
    model = AutoEncoderSystem()
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.save_checkpoint("best_model.ckpt")


    # create the PyTorch model and load the checkpoint weights
    model = AutoEncoderProd()
    checkpoint = torch.load("best_model.ckpt")
    hyper_parameters = checkpoint["hyper_parameters"]

    # if you want to restore any hyperparameters, you can pass them too
    model = AutoEncoderProd(**hyper_parameters)

    state_dict = checkpoint["state_dict"]

    # update keys by dropping `auto_encoder.`
    for key in list(state_dict):
        state_dict[key.replace("auto_encoder.", "")] = state_dict.pop(key)

    model.load_state_dict(state_dict)
    model.eval()
    x = torch.randn(1, 64)

    with torch.no_grad():
        y_hat = model(x)