100 lines
2.9 KiB
ReStructuredText
100 lines
2.9 KiB
ReStructuredText
|
############################################
Deploy models into production (intermediate)
############################################

**Audience**: Researchers and MLEs looking to use their models for predictions without Lightning dependencies.

----

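*********************
Use PyTorch as normal
*********************

If you prefer to use PyTorch directly, feel free to use any Lightning checkpoint without Lightning.
A Lightning checkpoint is a dictionary rather than a model object, so create your model first and then
load the weights stored under the ``"state_dict"`` key:

.. code-block:: python

    import torch
    from torch import nn


    class MyModel(nn.Module):
        # must define the same submodules (with the same attribute names) as your
        # LightningModule so that the keys in the checkpoint's state_dict match
        ...


    model = MyModel()
    checkpoint = torch.load("path/to/lightning/checkpoint.ckpt")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
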
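You can also pull out only the specific modules you want from the checkpoint. For example, if your LightningModule
stores its encoder as ``self.encoder``, the encoder's weights are saved under keys prefixed with ``encoder.``.
Here ``Encoder`` is the same :class:`torch.nn.Module` class your LightningModule uses:

.. code-block:: python

    checkpoint = torch.load("path/to/lightning/checkpoint.ckpt")

    # keep only the encoder's weights and strip the `encoder.` prefix from their keys
    prefix = "encoder."
    encoder_weights = {
        key[len(prefix) :]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    }

    encoder = Encoder()
    encoder.load_state_dict(encoder_weights)
    encoder.eval()

----
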
********************************************
Extract nn.Module from Lightning checkpoints
********************************************

You can also load the saved checkpoint and use it as a regular :class:`torch.nn.Module`. To do this, extract the
:class:`torch.nn.Module` instances used by your LightningModule and load their weights from the checkpoint saved after training.
We recommend copying the exact implementation from your LightningModule's ``__init__`` and ``forward`` methods, so that the
module structure, and therefore the keys of the checkpoint's ``state_dict``, match.

.. code-block:: python

    class Encoder(nn.Module):
        ...


    class Decoder(nn.Module):
        ...


    class AutoEncoderProd(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, x):
            return self.encoder(x)


    class AutoEncoderSystem(LightningModule):
        def __init__(self):
            super().__init__()
            self.auto_encoder = AutoEncoderProd()

        def forward(self, x):
            return self.auto_encoder.encoder(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.auto_encoder.encoder(x)
            y_hat = self.auto_encoder.decoder(y_hat)
            loss = ...
            return loss


    # train it
    trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
    model = AutoEncoderSystem()
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.save_checkpoint("best_model.ckpt")


    # load the checkpoint saved by Lightning
    checkpoint = torch.load("best_model.ckpt")

    # if your LightningModule called `self.save_hyperparameters()`, its init arguments are
    # stored under "hyper_parameters" and you can pass them to the PyTorch model too
    # (as long as its `__init__` accepts them); AutoEncoderSystem above saves none
    hyper_parameters = checkpoint.get("hyper_parameters", {})

    # create the PyTorch model
    model = AutoEncoderProd(**hyper_parameters)

    model_weights = checkpoint["state_dict"]

    # update keys by dropping `auto_encoder.`, the attribute name under which
    # AutoEncoderSystem stores the AutoEncoderProd instance
    for key in list(model_weights):
        model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)

    model.load_state_dict(model_weights)
    model.eval()
    x = torch.randn(1, 64)

    with torch.no_grad():
        y_hat = model(x)