lightning/docs/source-pytorch/deploy/production_intermediate.rst

99 lines
2.8 KiB
ReStructuredText

############################################
Deploy models into production (intermediate)
############################################
**Audience**: Researchers and MLEs looking to use their models for predictions without Lightning dependencies.
----
*********************
Use PyTorch as normal
*********************
If you prefer to use PyTorch directly, feel free to use any Lightning checkpoint without Lightning.
.. code-block:: python
import torch
class MyModel(nn.Module):
...
model = MyModel()
checkpoint = torch.load("path/to/lightning/checkpoint.ckpt")
model.load_state_dict(checkpoint["state_dict"])
model.eval()
----
********************************************
Extract nn.Module from Lightning checkpoints
********************************************
You can also load the saved checkpoint and use it as a regular :class:`torch.nn.Module`. You can extract all your :class:`torch.nn.Module`
and load the weights using the checkpoint saved using LightningModule after training. For this, we recommend copying the exact implementation
from your LightningModule ``init`` and ``forward`` method.
.. code-block:: python
class Encoder(nn.Module):
...
class Decoder(nn.Module):
...
class AutoEncoderProd(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, x):
return self.encoder(x)
class AutoEncoderSystem(LightningModule):
def __init__(self):
super().__init__()
self.auto_encoder = AutoEncoderProd()
def forward(self, x):
return self.auto_encoder.encoder(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.auto_encoder.encoder(x)
y_hat = self.auto_encoder.decoder(y_hat)
loss = ...
return loss
# train it
trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
model = AutoEncoderSystem()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")
# create the PyTorch model and load the checkpoint weights
model = AutoEncoderProd()
checkpoint = torch.load("best_model.ckpt")
hyper_parameters = checkpoint["hyper_parameters"]
# if you want to restore any hyperparameters, you can pass them too
model = AutoEncoderProd(**hyper_parameters)
model_weights = checkpoint["state_dict"]
# update keys by dropping `auto_encoder.`
for key in list(model_weights):
model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)
model.load_state_dict(model_weights)
model.eval()
x = torch.randn(1, 64)
with torch.no_grad():
y_hat = model(x)