Sample datatype for Serve Component (#15623)
* introducing serve component * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up tests * clean up tests * doctest * mypy * structure-fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup * cleanup * test fix * addition * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * requirements * getting future url * url for local * sample data typeg * changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * prediction * updates * updates * manifest * fix type error * fixed test Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Rick Izzo <rick@grid.ai> Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
parent
61d325376a
commit
136a090f99
|
@ -0,0 +1,42 @@
|
|||
# !pip install torchvision pydantic
|
||||
import base64
|
||||
import io
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
import lightning as L
|
||||
from lightning.app.components.serve import Image as InputImage
|
||||
from lightning.app.components.serve import PythonServer
|
||||
|
||||
|
||||
class PyTorchServer(PythonServer):
|
||||
def setup(self):
|
||||
self._model = torchvision.models.resnet18(pretrained=True)
|
||||
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
self._model.to(self._device)
|
||||
|
||||
def predict(self, request):
|
||||
image = base64.b64decode(request.image.encode("utf-8"))
|
||||
image = Image.open(io.BytesIO(image))
|
||||
transforms = torchvision.transforms.Compose(
|
||||
[
|
||||
torchvision.transforms.Resize(224),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
image = transforms(image)
|
||||
image = image.to(self._device)
|
||||
prediction = self._model(image.unsqueeze(0))
|
||||
return {"prediction": prediction.argmax().item()}
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
prediction: int
|
||||
|
||||
|
||||
component = PyTorchServer(input_type=InputImage, output_type=OutputData, cloud_compute=L.CloudCompute("gpu"))
|
||||
app = L.LightningApp(component)
|
|
@ -35,6 +35,7 @@ def _adjust_manifest(**kwargs: Any) -> None:
|
|||
"recursive-include requirements *.txt",
|
||||
"recursive-include src/lightning/app/ui *",
|
||||
"recursive-include src/lightning/cli/*-template *", # Add templates as build-in
|
||||
"include src/lightning/app/components/serve/catimage.png" + os.linesep,
|
||||
# fixme: this is strange, this shall work with setup find package - include
|
||||
"prune src/lightning_app",
|
||||
"prune src/lightning_lite",
|
||||
|
|
|
@ -50,6 +50,7 @@ def _adjust_manifest(**__: Any) -> None:
|
|||
"recursive-exclude src *.md" + os.linesep,
|
||||
"recursive-exclude requirements *.txt" + os.linesep,
|
||||
"recursive-include src/lightning_app *.md" + os.linesep,
|
||||
"include src/lightning_app/components/serve/catimage.png" + os.linesep,
|
||||
"recursive-include requirements/app *.txt" + os.linesep,
|
||||
"recursive-include src/lightning_app/cli/*-template *" + os.linesep, # Add templates
|
||||
]
|
||||
|
|
|
@ -9,7 +9,7 @@ from lightning_app.components.multi_node import (
|
|||
from lightning_app.components.python.popen import PopenPythonScript
|
||||
from lightning_app.components.python.tracer import Code, TracerPythonScript
|
||||
from lightning_app.components.serve.gradio import ServeGradio
|
||||
from lightning_app.components.serve.python_server import PythonServer
|
||||
from lightning_app.components.serve.python_server import Image, Number, PythonServer
|
||||
from lightning_app.components.serve.serve import ModelInferenceAPI
|
||||
from lightning_app.components.serve.streamlit import ServeStreamlit
|
||||
from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner
|
||||
|
@ -24,6 +24,8 @@ __all__ = [
|
|||
"ServeStreamlit",
|
||||
"ModelInferenceAPI",
|
||||
"PythonServer",
|
||||
"Image",
|
||||
"Number",
|
||||
"MultiNode",
|
||||
"LiteMultiNode",
|
||||
"LightningTrainingComponent",
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from lightning_app.components.serve.gradio import ServeGradio
|
||||
from lightning_app.components.serve.python_server import PythonServer
|
||||
from lightning_app.components.serve.python_server import Image, Number, PythonServer
|
||||
from lightning_app.components.serve.streamlit import ServeStreamlit
|
||||
|
||||
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer"]
|
||||
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
|
@ -1,5 +1,7 @@
|
|||
import abc
|
||||
from typing import Any, Dict
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
@ -12,6 +14,12 @@ from lightning_app.utilities.app_helpers import Logger
|
|||
logger = Logger(__name__)
|
||||
|
||||
|
||||
def image_to_base64(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return encoded_string.decode("UTF-8")
|
||||
|
||||
|
||||
class _DefaultInputData(BaseModel):
|
||||
payload: str
|
||||
|
||||
|
@ -20,6 +28,25 @@ class _DefaultOutputData(BaseModel):
|
|||
prediction: str
|
||||
|
||||
|
||||
class Image(BaseModel):
|
||||
image: Optional[str]
|
||||
|
||||
@staticmethod
|
||||
def _get_sample_data() -> Dict[Any, Any]:
|
||||
imagepath = Path(__file__).absolute().parent / "catimage.png"
|
||||
with open(imagepath, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return {"image": encoded_string.decode("UTF-8")}
|
||||
|
||||
|
||||
class Number(BaseModel):
|
||||
prediction: Optional[int]
|
||||
|
||||
@staticmethod
|
||||
def _get_sample_data() -> Dict[Any, Any]:
|
||||
return {"prediction": 463}
|
||||
|
||||
|
||||
class PythonServer(LightningWork, abc.ABC):
|
||||
def __init__( # type: ignore
|
||||
self,
|
||||
|
@ -110,6 +137,9 @@ class PythonServer(LightningWork, abc.ABC):
|
|||
|
||||
@staticmethod
|
||||
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
|
||||
if hasattr(datatype, "_get_sample_data"):
|
||||
return datatype._get_sample_data()
|
||||
|
||||
datatype_props = datatype.schema()["properties"]
|
||||
out: Dict[str, Any] = {}
|
||||
for k, v in datatype_props.items():
|
||||
|
@ -141,7 +171,7 @@ class PythonServer(LightningWork, abc.ABC):
|
|||
url = self._future_url if self._future_url else self.url
|
||||
if not url:
|
||||
# if the url is still empty, point it to localhost
|
||||
url = f"http://127.0.0.1{self.port}"
|
||||
url = f"http://127.0.0.1:{self.port}"
|
||||
url = f"{url}/predict"
|
||||
datatype_parse_error = False
|
||||
try:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import multiprocessing as mp
|
||||
|
||||
from lightning_app.components import PythonServer
|
||||
from lightning_app.components import Image, Number, PythonServer
|
||||
from lightning_app.utilities.network import _configure_session, find_free_network_port
|
||||
|
||||
|
||||
|
@ -29,3 +29,17 @@ def test_python_server_component():
|
|||
res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
|
||||
process.terminate()
|
||||
assert res.json()["prediction"] == "test"
|
||||
|
||||
|
||||
def test_image_sample_data():
|
||||
data = Image()._get_sample_data()
|
||||
assert isinstance(data, dict)
|
||||
assert "image" in data
|
||||
assert len(data["image"]) > 100
|
||||
|
||||
|
||||
def test_number_sample_data():
|
||||
data = Number()._get_sample_data()
|
||||
assert isinstance(data, dict)
|
||||
assert "prediction" in data
|
||||
assert data["prediction"] == 463
|
||||
|
|
Loading…
Reference in New Issue