:orphan:

########################################
Deploy models into production (advanced)
########################################

**Audience**: Machine learning engineers optimizing models for enterprise-scale production environments.

----

*********************************
Compile your model to TorchScript
*********************************

`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ lets you serialize your models so that they can be loaded in non-Python environments.
The ``LightningModule`` provides a handy method, :meth:`~pytorch_lightning.core.module.LightningModule.to_torchscript`, which returns a scripted module that you
can save or use directly.

.. testcode:: python

    import torch
    from pytorch_lightning import LightningModule


    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            # flatten all non-batch dimensions, then apply the linear layer and ReLU
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()
    script = model.to_torchscript()

    # save for use in production environment
    torch.jit.save(script, "model.pt")

It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.

Once you have exported the model, you can run it in the PyTorch or C++ runtime:

.. code-block:: python

    import torch

    inp = torch.rand(1, 64)
    # loading does not require the original Python class definition
    scripted_module = torch.jit.load("model.pt")
    output = scripted_module(inp)
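
To sanity-check an exported model, you can compare the reloaded module's output with the eager model's output on the same input. The sketch below assumes only PyTorch is installed: ``SimpleNet`` is a hypothetical plain ``torch.nn.Module`` standing in for ``SimpleModel``, it calls :func:`torch.jit.script` directly instead of ``to_torchscript``, and it serializes to an in-memory ``io.BytesIO`` buffer rather than ``model.pt``.

.. code-block:: python

    import io

    import torch


    class SimpleNet(torch.nn.Module):
        # hypothetical plain-PyTorch stand-in for SimpleModel
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    model = SimpleNet().eval()
    script = torch.jit.script(model)

    # serialize to an in-memory buffer instead of "model.pt"
    buffer = io.BytesIO()
    torch.jit.save(script, buffer)
    buffer.seek(0)
    loaded = torch.jit.load(buffer)

    inp = torch.rand(1, 64)
    with torch.no_grad():
        eager_out = model(inp)
        scripted_out = loaded(inp)

    # the reloaded module should reproduce the eager output
    print(torch.allclose(eager_out, scripted_out))  # True

The same comparison works for a module produced by ``to_torchscript``, since it is also a TorchScript module.
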
By default, only ``forward`` (and the methods it calls) is compiled. To include another method in the scripted module, decorate it with :func:`torch.jit.export`:

.. code-block:: python

    import pytorch_lightning as pl
    import torch
    from torch import nn


    class LitMCdropoutModel(pl.LightningModule):
        def __init__(self, model, mc_iteration):
            super().__init__()
            self.model = model
            self.dropout = nn.Dropout()
            self.mc_iteration = mc_iteration

        @torch.jit.export
        def predict_step(self, batch: torch.Tensor, batch_idx: int):
            # enable Monte Carlo Dropout
            self.dropout.train()

            # take average of `self.mc_iteration` iterations
            pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
            pred = torch.vstack(pred).mean(dim=0)
            return pred


    model = LitMCdropoutModel(...)
    script = model.to_torchscript(file_path="model.pt", method="script")
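
Methods decorated with :func:`torch.jit.export` remain callable on the module after it is saved and reloaded. The following is a minimal sketch in plain PyTorch: ``MCDropoutNet`` and ``mc_predict`` are hypothetical names, and it swaps the ``self.dropout.train()`` call for ``torch.nn.functional.dropout(..., training=True)`` so that dropout stays active even when the module is in eval mode.

.. code-block:: python

    import io
    from typing import List

    import torch
    from torch import nn


    class MCDropoutNet(nn.Module):
        # hypothetical stand-in: plain nn.Module with an exported MC-dropout method
        def __init__(self, mc_iteration: int = 10, p: float = 0.5):
            super().__init__()
            self.l1 = nn.Linear(64, 4)
            self.mc_iteration = mc_iteration
            self.p = p

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.l1(x)

        @torch.jit.export
        def mc_predict(self, x: torch.Tensor) -> torch.Tensor:
            preds: List[torch.Tensor] = []
            for _ in range(self.mc_iteration):
                # training=True keeps dropout on regardless of eval() mode
                dropped = nn.functional.dropout(x, p=self.p, training=True)
                preds.append(self.l1(dropped).unsqueeze(0))
            # average over the Monte Carlo samples
            return torch.vstack(preds).mean(dim=0)


    scripted = torch.jit.script(MCDropoutNet().eval())
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    loaded = torch.jit.load(buffer)

    x = torch.rand(2, 64)
    mean_pred = loaded.mc_predict(x)  # exported method survives save/load
    print(mean_pred.shape)  # torch.Size([2, 4])

Calling ``loaded(x)`` still runs ``forward``; the exported ``mc_predict`` is an additional entry point on the same serialized module.
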