# ! pip install torch torchvision
from typing import List

import torch
import torchvision
from pydantic import BaseModel

import lightning as L


# request/response schemas: the autoscaler batches individual requests, so the
# server receives a list of images and returns a list of predictions
class BatchRequestModel(BaseModel):
    inputs: List[L.app.components.Image]


class BatchResponse(BaseModel):
    outputs: List[L.app.components.Number]

class PyTorchServer(L.app.components.PythonServer):
    def __init__(self, *args, **kwargs):
        super().__init__(
            input_type=BatchRequestModel,
            output_type=BatchResponse,
            *args,
            **kwargs,
        )

    def setup(self):
        # called once when the work starts: pick a device and load the model
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # note: recent torchvision releases deprecate pretrained=True in favor of
        # weights=torchvision.models.ResNet18_Weights.DEFAULT
        self._model = torchvision.models.resnet18(pretrained=True).to(self._device)

    def predict(self, requests: BatchRequestModel):
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        # decode and preprocess every image in the batch, then stack into one tensor
        images = []
        for request in requests.inputs:
            image = L.app.components.serve.types.image.Image.deserialize(request.image)
            image = transforms(image).unsqueeze(0)
            images.append(image)
        images = torch.cat(images)
        images = images.to(self._device)
        predictions = self._model(images)
        # return the argmax class index for each image in the batch
        results = predictions.argmax(1).cpu().numpy().tolist()
        return BatchResponse(outputs=[{"prediction": pred} for pred in results])
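
# note on batching (a summary based on the AutoScaler docs, not part of the
# original example): incoming requests are buffered and forwarded to a replica
# as one BatchRequestModel once max_batch_size requests have arrived or
# timeout_batching seconds have passed, whichever comes first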

class MyAutoScaler(L.app.components.AutoScaler):
    def scale(self, replicas: int, metrics: dict) -> int:
        """The default scaling logic that users can override."""
        # scale out if the pending requests per work meet or exceed the max batch size
        max_requests_per_work = self.max_batch_size
        pending_requests_per_work = metrics["pending_requests"] / (replicas + metrics["pending_works"])
        if pending_requests_per_work >= max_requests_per_work:
            return replicas + 1

        # scale in if the pending requests per work drop below 25% of max_requests_per_work
        min_requests_per_work = max_requests_per_work * 0.25
        pending_requests_per_work = metrics["pending_requests"] / replicas
        if pending_requests_per_work < min_requests_per_work:
            return replicas - 1

        return replicas
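
# worked example of the rule above, with max_batch_size=8: 2 replicas, 0 pending
# works, and 20 pending requests give 20 / (2 + 0) = 10 >= 8, so scale out to 3;
# 2 replicas and 3 pending requests give 3 / 2 = 1.5 < 8 * 0.25 = 2, so scale in to 1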

app = L.LightningApp(
    MyAutoScaler(
        # work class and args
        PyTorchServer,
        cloud_compute=L.CloudCompute("gpu"),
        # autoscaler specific args
        min_replicas=1,
        max_replicas=4,
        scale_out_interval=10,
        scale_in_interval=10,
        endpoint="predict",
        input_type=L.app.components.Image,
        output_type=L.app.components.Number,
        timeout_batching=1,
        max_batch_size=8,
    )
)
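
# minimal usage sketch (not part of the original example; the host/port and the
# request field layout are assumptions): start the app with the Lightning CLI,
# then POST a single Image payload to the autoscaled /predict endpoint. The
# actual URL is printed when the app starts.
#
#   lightning run app app.py
#   curl -X POST http://<host>:<port>/predict \
#        -H "Content-Type: application/json" \
#        -d '{"image": "<base64-encoded image>"}'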