:orphan:
########################################
Deploy models into production (advanced)
########################################
**Audience**: Machine learning engineers optimizing models for enterprise-scale production environments.
----
*********************************
Compile your model to TorchScript
*********************************
`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ allows you to serialize your models so that they can be loaded in non-Python environments.
The ``LightningModule`` has a handy method :meth:`~pytorch_lightning.core.module.LightningModule.to_torchscript` that returns a scripted module which you can save or use directly.

.. testcode:: python

    import torch
    from pytorch_lightning import LightningModule


    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            # flatten each sample, then apply the linear layer and ReLU
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # create the model
    model = SimpleModel()
    script = model.to_torchscript()

    # save for use in production environment
    torch.jit.save(script, "model.pt")

It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
Once you have the exported model, you can run it in the PyTorch or C++ runtime:

.. code-block:: python

    import torch

    inp = torch.rand(1, 64)
    scripted_module = torch.jit.load("model.pt")
    output = scripted_module(inp)

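
Before shipping an exported model, it can be worth checking that the serialized module reproduces the eager model's outputs. The following is a minimal sketch under stated assumptions: ``TinyNet`` is a hypothetical plain ``torch.nn.Module`` stand-in for the ``LightningModule`` above (so the snippet only needs PyTorch), and the module is written to an in-memory buffer rather than to ``model.pt``.

```python
import io

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in with the same layers as SimpleModel."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


model = TinyNet().eval()
scripted = torch.jit.script(model)

# Serialize to an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)  # rewind so torch.jit.load reads from the start
restored = torch.jit.load(buffer)

# Compare the eager model and the reloaded scripted module on the same input.
inp = torch.rand(2, 64)
with torch.no_grad():
    same = torch.allclose(model(inp), restored(inp))
```

If ``same`` is ``False``, the export likely changed the model's behavior and should be investigated before deployment.
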
To script a method other than ``forward``, decorate it with :func:`torch.jit.export`:

.. code-block:: python

    import torch
    from torch import nn

    import pytorch_lightning as pl


    class LitMCdropoutModel(pl.LightningModule):
        def __init__(self, model, mc_iteration):
            super().__init__()
            self.model = model
            self.dropout = nn.Dropout()
            self.mc_iteration = mc_iteration

        @torch.jit.export
        def predict_step(self, batch, batch_idx):
            # enable Monte Carlo Dropout
            self.dropout.train()

            # take average of `self.mc_iteration` iterations
            pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
            pred = torch.vstack(pred).mean(dim=0)
            return pred


    model = LitMCdropoutModel(...)
    script = model.to_torchscript(file_path="model.pt", method="script")

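
An exported method remains callable by name after the module is saved and reloaded. The sketch below illustrates this with plain PyTorch and is not the Lightning implementation: ``McDropoutNet`` and ``predict_mc`` are hypothetical names, the wrapped model is an arbitrary ``nn.Linear``, and dropout is applied through ``torch.nn.functional.dropout`` with ``training=True`` so that it stays active regardless of the module's mode.

```python
import io

import torch
import torch.nn.functional as F
from torch import nn


class McDropoutNet(nn.Module):
    """Hypothetical module exposing an extra scripted method."""

    def __init__(self, mc_iteration: int = 8):
        super().__init__()
        self.model = nn.Linear(16, 3)
        self.mc_iteration = mc_iteration

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    @torch.jit.export
    def predict_mc(self, x: torch.Tensor) -> torch.Tensor:
        # Average several stochastic passes; dropout is forced on.
        total = torch.zeros_like(self.model(x))
        for _ in range(self.mc_iteration):
            total = total + F.dropout(self.model(x), p=0.5, training=True)
        return total / self.mc_iteration


# Scripting compiles forward plus every method marked with torch.jit.export.
scripted = torch.jit.script(McDropoutNet())

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
restored = torch.jit.load(buffer)

inp = torch.rand(4, 16)
out = restored.predict_mc(inp)  # call the exported method on the loaded module
```

Methods without the decorator that are not reachable from ``forward`` are not compiled, so they are not available on the loaded module.
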