Sample datatype for Serve Component (#15623)

* introducing serve component

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean up tests

* clean up tests

* doctest

* mypy

* structure-fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup

* cleanup

* test fix

* addition

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* requirements

* getting future url

* url for local

* sample data typeg

* changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* prediction

* updates

* updates

* manifest

* fix type error

* fixed test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rick Izzo <rick@grid.ai>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
Sherin Thomas 2022-11-11 01:09:36 +05:30 committed by GitHub
parent 61d325376a
commit 136a090f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 96 additions and 6 deletions

View File

@ -0,0 +1,42 @@
# !pip install torchvision pydantic
import base64
import io
import torch
import torchvision
from PIL import Image
from pydantic import BaseModel
import lightning as L
from lightning.app.components.serve import Image as InputImage
from lightning.app.components.serve import PythonServer
class PyTorchServer(PythonServer):
def setup(self):
self._model = torchvision.models.resnet18(pretrained=True)
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model.to(self._device)
def predict(self, request):
image = base64.b64decode(request.image.encode("utf-8"))
image = Image.open(io.BytesIO(image))
transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image = transforms(image)
image = image.to(self._device)
prediction = self._model(image.unsqueeze(0))
return {"prediction": prediction.argmax().item()}
class OutputData(BaseModel):
prediction: int
component = PyTorchServer(input_type=InputImage, output_type=OutputData, cloud_compute=L.CloudCompute("gpu"))
app = L.LightningApp(component)

View File

@ -35,6 +35,7 @@ def _adjust_manifest(**kwargs: Any) -> None:
"recursive-include requirements *.txt", "recursive-include requirements *.txt",
"recursive-include src/lightning/app/ui *", "recursive-include src/lightning/app/ui *",
"recursive-include src/lightning/cli/*-template *", # Add templates as build-in "recursive-include src/lightning/cli/*-template *", # Add templates as build-in
"include src/lightning/app/components/serve/catimage.png" + os.linesep,
# fixme: this is strange, this shall work with setup find package - include # fixme: this is strange, this shall work with setup find package - include
"prune src/lightning_app", "prune src/lightning_app",
"prune src/lightning_lite", "prune src/lightning_lite",

View File

@ -50,6 +50,7 @@ def _adjust_manifest(**__: Any) -> None:
"recursive-exclude src *.md" + os.linesep, "recursive-exclude src *.md" + os.linesep,
"recursive-exclude requirements *.txt" + os.linesep, "recursive-exclude requirements *.txt" + os.linesep,
"recursive-include src/lightning_app *.md" + os.linesep, "recursive-include src/lightning_app *.md" + os.linesep,
"include src/lightning_app/components/serve/catimage.png" + os.linesep,
"recursive-include requirements/app *.txt" + os.linesep, "recursive-include requirements/app *.txt" + os.linesep,
"recursive-include src/lightning_app/cli/*-template *" + os.linesep, # Add templates "recursive-include src/lightning_app/cli/*-template *" + os.linesep, # Add templates
] ]

View File

@ -9,7 +9,7 @@ from lightning_app.components.multi_node import (
from lightning_app.components.python.popen import PopenPythonScript from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript from lightning_app.components.python.tracer import Code, TracerPythonScript
from lightning_app.components.serve.gradio import ServeGradio from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import PythonServer from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.serve import ModelInferenceAPI from lightning_app.components.serve.serve import ModelInferenceAPI
from lightning_app.components.serve.streamlit import ServeStreamlit from lightning_app.components.serve.streamlit import ServeStreamlit
from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner
@ -24,6 +24,8 @@ __all__ = [
"ServeStreamlit", "ServeStreamlit",
"ModelInferenceAPI", "ModelInferenceAPI",
"PythonServer", "PythonServer",
"Image",
"Number",
"MultiNode", "MultiNode",
"LiteMultiNode", "LiteMultiNode",
"LightningTrainingComponent", "LightningTrainingComponent",

View File

@ -1,5 +1,5 @@
from lightning_app.components.serve.gradio import ServeGradio from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import PythonServer from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.streamlit import ServeStreamlit from lightning_app.components.serve.streamlit import ServeStreamlit
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer"] __all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

View File

@ -1,5 +1,7 @@
import abc import abc
from typing import Any, Dict import base64
from pathlib import Path
from typing import Any, Dict, Optional
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
@ -12,6 +14,12 @@ from lightning_app.utilities.app_helpers import Logger
logger = Logger(__name__) logger = Logger(__name__)
def image_to_base64(image_path):
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return encoded_string.decode("UTF-8")
class _DefaultInputData(BaseModel): class _DefaultInputData(BaseModel):
payload: str payload: str
@ -20,6 +28,25 @@ class _DefaultOutputData(BaseModel):
prediction: str prediction: str
class Image(BaseModel):
image: Optional[str]
@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
imagepath = Path(__file__).absolute().parent / "catimage.png"
with open(imagepath, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return {"image": encoded_string.decode("UTF-8")}
class Number(BaseModel):
prediction: Optional[int]
@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
return {"prediction": 463}
class PythonServer(LightningWork, abc.ABC): class PythonServer(LightningWork, abc.ABC):
def __init__( # type: ignore def __init__( # type: ignore
self, self,
@ -110,6 +137,9 @@ class PythonServer(LightningWork, abc.ABC):
@staticmethod @staticmethod
def _get_sample_dict_from_datatype(datatype: Any) -> dict: def _get_sample_dict_from_datatype(datatype: Any) -> dict:
if hasattr(datatype, "_get_sample_data"):
return datatype._get_sample_data()
datatype_props = datatype.schema()["properties"] datatype_props = datatype.schema()["properties"]
out: Dict[str, Any] = {} out: Dict[str, Any] = {}
for k, v in datatype_props.items(): for k, v in datatype_props.items():
@ -141,7 +171,7 @@ class PythonServer(LightningWork, abc.ABC):
url = self._future_url if self._future_url else self.url url = self._future_url if self._future_url else self.url
if not url: if not url:
# if the url is still empty, point it to localhost # if the url is still empty, point it to localhost
url = f"http://127.0.0.1{self.port}" url = f"http://127.0.0.1:{self.port}"
url = f"{url}/predict" url = f"{url}/predict"
datatype_parse_error = False datatype_parse_error = False
try: try:

View File

@ -1,6 +1,6 @@
import multiprocessing as mp import multiprocessing as mp
from lightning_app.components import PythonServer from lightning_app.components import Image, Number, PythonServer
from lightning_app.utilities.network import _configure_session, find_free_network_port from lightning_app.utilities.network import _configure_session, find_free_network_port
@ -29,3 +29,17 @@ def test_python_server_component():
res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"}) res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
process.terminate() process.terminate()
assert res.json()["prediction"] == "test" assert res.json()["prediction"] == "test"
def test_image_sample_data():
data = Image()._get_sample_data()
assert isinstance(data, dict)
assert "image" in data
assert len(data["image"]) > 100
def test_number_sample_data():
data = Number()._get_sample_data()
assert isinstance(data, dict)
assert "prediction" in data
assert data["prediction"] == 463