Sample datatype for Serve Component (#15623)

* introducing serve component

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean up tests

* clean up tests

* doctest

* mypy

* structure-fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup

* cleanup

* test fix

* addition

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* requirements

* getting future url

* url for local

* sample data typeg

* changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* prediction

* updates

* updates

* manifest

* fix type error

* fixed test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rick Izzo <rick@grid.ai>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
Sherin Thomas 2022-11-11 01:09:36 +05:30 committed by GitHub
parent 61d325376a
commit 136a090f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 96 additions and 6 deletions

View File

@ -0,0 +1,42 @@
# !pip install torchvision pydantic
import base64
import io
import torch
import torchvision
from PIL import Image
from pydantic import BaseModel
import lightning as L
from lightning.app.components.serve import Image as InputImage
from lightning.app.components.serve import PythonServer
class PyTorchServer(PythonServer):
def setup(self):
self._model = torchvision.models.resnet18(pretrained=True)
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model.to(self._device)
def predict(self, request):
image = base64.b64decode(request.image.encode("utf-8"))
image = Image.open(io.BytesIO(image))
transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image = transforms(image)
image = image.to(self._device)
prediction = self._model(image.unsqueeze(0))
return {"prediction": prediction.argmax().item()}
class OutputData(BaseModel):
prediction: int
component = PyTorchServer(input_type=InputImage, output_type=OutputData, cloud_compute=L.CloudCompute("gpu"))
app = L.LightningApp(component)

View File

@ -35,6 +35,7 @@ def _adjust_manifest(**kwargs: Any) -> None:
"recursive-include requirements *.txt",
"recursive-include src/lightning/app/ui *",
"recursive-include src/lightning/cli/*-template *", # Add templates as build-in
"include src/lightning/app/components/serve/catimage.png" + os.linesep,
# fixme: this is strange, this shall work with setup find package - include
"prune src/lightning_app",
"prune src/lightning_lite",

View File

@ -50,6 +50,7 @@ def _adjust_manifest(**__: Any) -> None:
"recursive-exclude src *.md" + os.linesep,
"recursive-exclude requirements *.txt" + os.linesep,
"recursive-include src/lightning_app *.md" + os.linesep,
"include src/lightning_app/components/serve/catimage.png" + os.linesep,
"recursive-include requirements/app *.txt" + os.linesep,
"recursive-include src/lightning_app/cli/*-template *" + os.linesep, # Add templates
]

View File

@ -9,7 +9,7 @@ from lightning_app.components.multi_node import (
from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import PythonServer
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.serve import ModelInferenceAPI
from lightning_app.components.serve.streamlit import ServeStreamlit
from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner
@ -24,6 +24,8 @@ __all__ = [
"ServeStreamlit",
"ModelInferenceAPI",
"PythonServer",
"Image",
"Number",
"MultiNode",
"LiteMultiNode",
"LightningTrainingComponent",

View File

@ -1,5 +1,5 @@
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import PythonServer
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.streamlit import ServeStreamlit
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer"]
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

View File

@ -1,5 +1,7 @@
import abc
from typing import Any, Dict
import base64
from pathlib import Path
from typing import Any, Dict, Optional
import uvicorn
from fastapi import FastAPI
@ -12,6 +14,12 @@ from lightning_app.utilities.app_helpers import Logger
logger = Logger(__name__)
def image_to_base64(image_path):
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return encoded_string.decode("UTF-8")
class _DefaultInputData(BaseModel):
payload: str
@ -20,6 +28,25 @@ class _DefaultOutputData(BaseModel):
prediction: str
class Image(BaseModel):
image: Optional[str]
@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
imagepath = Path(__file__).absolute().parent / "catimage.png"
with open(imagepath, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return {"image": encoded_string.decode("UTF-8")}
class Number(BaseModel):
prediction: Optional[int]
@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
return {"prediction": 463}
class PythonServer(LightningWork, abc.ABC):
def __init__( # type: ignore
self,
@ -110,6 +137,9 @@ class PythonServer(LightningWork, abc.ABC):
@staticmethod
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
if hasattr(datatype, "_get_sample_data"):
return datatype._get_sample_data()
datatype_props = datatype.schema()["properties"]
out: Dict[str, Any] = {}
for k, v in datatype_props.items():
@ -141,7 +171,7 @@ class PythonServer(LightningWork, abc.ABC):
url = self._future_url if self._future_url else self.url
if not url:
# if the url is still empty, point it to localhost
url = f"http://127.0.0.1{self.port}"
url = f"http://127.0.0.1:{self.port}"
url = f"{url}/predict"
datatype_parse_error = False
try:

View File

@ -1,6 +1,6 @@
import multiprocessing as mp
from lightning_app.components import PythonServer
from lightning_app.components import Image, Number, PythonServer
from lightning_app.utilities.network import _configure_session, find_free_network_port
@ -29,3 +29,17 @@ def test_python_server_component():
res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
process.terminate()
assert res.json()["prediction"] == "test"
def test_image_sample_data():
data = Image()._get_sample_data()
assert isinstance(data, dict)
assert "image" in data
assert len(data["image"]) > 100
def test_number_sample_data():
data = Number()._get_sample_data()
assert isinstance(data, dict)
assert "prediction" in data
assert data["prediction"] == 463