lightning/examples/pl_servable_module/production.py

125 lines
4.1 KiB
Python

import base64
from dataclasses import dataclass
from io import BytesIO
from os import path
from typing import Dict, Optional
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image as PILImage
from lightning.pytorch import cli_lightning_logo, LightningDataModule, LightningModule
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.serve import ServableModule, ServableModuleValidator
from lightning.pytorch.utilities.model_helpers import get_torchvision_model
# Dataset root: a "Datasets" folder two levels above the directory holding this file.
DATASETS_PATH = path.join(path.dirname(__file__), path.pardir, path.pardir, "Datasets")
class LitModule(LightningModule):
    """Classifier built on a torchvision model whose ``fc`` layer is replaced by a 10-output head."""

    def __init__(self, name: str = "resnet18"):
        """Create the network and the loss.

        Args:
            name: Model name forwarded to ``get_torchvision_model``; the model is
                requested with ``weights="DEFAULT"``.
        """
        super().__init__()
        backbone = get_torchvision_model(name, weights="DEFAULT")
        # Swap the final fully connected layer for one with 10 outputs.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, 10)
        self.model = backbone
        self.criterion = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for one training batch."""
        images, targets = batch
        logits = self.model(images)
        train_loss = self.criterion(logits, targets)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and log it as ``val_loss``."""
        images, targets = batch
        logits = self.model(images)
        val_loss = self.criterion(logits, targets)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        """Return an SGD optimizer (lr=0.001, momentum=0.9) over all module parameters."""
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        return optimizer
class CIFAR10DataModule(LightningDataModule):
    """Data module providing CIFAR10 training and validation loaders rooted at ``DATASETS_PATH``."""

    # Applied to every sample: resize to 256, center-crop to 224, convert to a tensor.
    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

    def train_dataloader(self, *args, **kwargs):
        """Return a shuffled loader (batch size 2) over the CIFAR10 training split.

        The dataset is created with ``download=True`` under ``DATASETS_PATH``.
        """
        trainset = torchvision.datasets.CIFAR10(root=DATASETS_PATH, train=True, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True, num_workers=0)

    def val_dataloader(self, *args, **kwargs):
        """Return an unshuffled loader (batch size 2) over the CIFAR10 test split.

        The dataset is created with ``train=False`` and ``download=True`` under
        ``DATASETS_PATH``.
        """
        valset = torchvision.datasets.CIFAR10(root=DATASETS_PATH, train=False, download=True, transform=self.transform)
        # Keep validation order fixed: shuffling evaluation data makes the subset
        # selected by ``limit_val_batches`` depend on RNG state, and Lightning
        # warns when a validation loader is shuffled.
        return torch.utils.data.DataLoader(valset, batch_size=2, shuffle=False, num_workers=0)
@dataclass(unsafe_hash=True)
class Image:
    """Deserializer turning a base64-encoded image string into a batched tensor."""

    # Target size; resizing happens only when both height and width are set (truthy).
    height: Optional[int] = None
    width: Optional[int] = None
    # Not read by ``deserialize``.
    extension: str = "JPEG"
    # PIL mode the decoded image is converted to before it becomes a tensor.
    mode: str = "RGB"
    # Not read by ``deserialize``.
    channel_first: bool = False

    def deserialize(self, data: str) -> torch.Tensor:
        """Decode ``data`` and return it as a tensor with a leading batch dimension.

        Args:
            data: Base64 text of an encoded image file.

        Returns:
            The output of torchvision's ``ToTensor`` for the decoded image,
            with an extra dimension inserted at position 0.
        """
        # Extra "=" are appended so payloads whose padding was stripped still decode.
        encoded_with_padding = (data + "===").encode("UTF-8")
        raw_bytes = base64.b64decode(encoded_with_padding)
        img = PILImage.open(BytesIO(raw_bytes), mode="r")
        # Honor the configured mode: without this, grayscale, palette or RGBA
        # payloads would produce tensors with an unexpected channel count.
        if self.mode and img.mode != self.mode:
            img = img.convert(self.mode)
        if self.height and self.width:
            # PIL expects the size as (width, height).
            img = img.resize((self.width, self.height))
        arr = np.array(img)
        return T.ToTensor()(arr).unsqueeze(0)
class Top1:
    """Serializer reducing a tensor of scores to the index of its largest entry."""

    def serialize(self, tensor: torch.Tensor) -> int:
        """Return the flat index of the highest softmax probability in ``tensor``.

        Args:
            tensor: Scores; softmax is taken over the last dimension.

        Returns:
            The index, as a Python ``int``, of the maximum over the flattened
            softmax output.
        """
        # Pass ``dim`` explicitly: omitting it is deprecated and emits a warning
        # on every call. For 1-D and 2-D inputs the last dimension is the same
        # one the implicit choice would pick.
        probabilities = torch.nn.functional.softmax(tensor, dim=-1)
        return probabilities.argmax().item()
class ProductionReadyModel(LitModule, ServableModule):
    """``LitModule`` combined with ``ServableModule``, supplying the serving hooks below."""

    def configure_payload(self):
        """Build a request payload from the first sample of the training dataset."""
        # Take the first (image, label) pair from the training dataset.
        sample, _ = self.trainer.train_dataloader.iterables.dataset[0]
        # Write the sample as an in-memory JPEG, then base64-encode it to UTF-8 text.
        jpeg_buffer = BytesIO()
        T.ToPILImage()(sample).save(jpeg_buffer, format="JPEG")
        encoded = base64.b64encode(jpeg_buffer.getvalue()).decode("UTF-8")
        return {"body": {"x": encoded}}

    def configure_serialization(self):
        """Return the input deserializers and output serializers, in that order."""
        deserializers = {"x": Image(224, 224).deserialize}
        serializers = {"output": Top1().serialize}
        return deserializers, serializers

    def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the model on ``x`` and return its raw output under the ``"output"`` key."""
        prediction = self.model(x)
        return {"output": prediction}

    def configure_response(self):
        """Return the fixed response ``{"output": 7}``.

        NOTE(review): presumably the class index the validator expects for the
        payload from ``configure_payload`` — confirm against ``ServableModuleValidator``.
        """
        expected = {"output": 7}
        return expected
def cli_main():
    """Construct the ``LightningCLI`` with ``run=False`` and fit the model on the data module."""
    # Short run (1 epoch, 5 train / 5 val batches) with the servable validator attached.
    trainer_defaults = {
        "callbacks": [ServableModuleValidator()],
        "max_epochs": 1,
        "limit_train_batches": 5,
        "limit_val_batches": 5,
    }
    cli = LightningCLI(
        ProductionReadyModel,
        CIFAR10DataModule,
        seed_everything_default=42,
        save_config_kwargs={"overwrite": True},
        run=False,
        trainer_defaults=trainer_defaults,
    )
    # With run=False the CLI only instantiates objects; fitting is started here.
    trainer = cli.trainer
    trainer.fit(cli.model, cli.datamodule)
# Script entry point: call cli_lightning_logo(), then start training via cli_main().
if __name__ == "__main__":
    cli_lightning_logo()
    cli_main()