Sample datatype for Serve Component (#15623)
* introducing serve component * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up tests * clean up tests * doctest * mypy * structure-fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup * cleanup * test fix * addition * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * requirements * getting future url * url for local * sample data typeg * changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * prediction * updates * updates * manifest * fix type error * fixed test Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Rick Izzo <rick@grid.ai> Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
parent
61d325376a
commit
136a090f99
|
@ -0,0 +1,42 @@
|
||||||
|
# !pip install torchvision pydantic
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from PIL import Image
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import lightning as L
|
||||||
|
from lightning.app.components.serve import Image as InputImage
|
||||||
|
from lightning.app.components.serve import PythonServer
|
||||||
|
|
||||||
|
|
||||||
|
class PyTorchServer(PythonServer):
|
||||||
|
def setup(self):
|
||||||
|
self._model = torchvision.models.resnet18(pretrained=True)
|
||||||
|
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
self._model.to(self._device)
|
||||||
|
|
||||||
|
def predict(self, request):
|
||||||
|
image = base64.b64decode(request.image.encode("utf-8"))
|
||||||
|
image = Image.open(io.BytesIO(image))
|
||||||
|
transforms = torchvision.transforms.Compose(
|
||||||
|
[
|
||||||
|
torchvision.transforms.Resize(224),
|
||||||
|
torchvision.transforms.ToTensor(),
|
||||||
|
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
image = transforms(image)
|
||||||
|
image = image.to(self._device)
|
||||||
|
prediction = self._model(image.unsqueeze(0))
|
||||||
|
return {"prediction": prediction.argmax().item()}
|
||||||
|
|
||||||
|
|
||||||
|
class OutputData(BaseModel):
|
||||||
|
prediction: int
|
||||||
|
|
||||||
|
|
||||||
|
component = PyTorchServer(input_type=InputImage, output_type=OutputData, cloud_compute=L.CloudCompute("gpu"))
|
||||||
|
app = L.LightningApp(component)
|
|
@ -35,6 +35,7 @@ def _adjust_manifest(**kwargs: Any) -> None:
|
||||||
"recursive-include requirements *.txt",
|
"recursive-include requirements *.txt",
|
||||||
"recursive-include src/lightning/app/ui *",
|
"recursive-include src/lightning/app/ui *",
|
||||||
"recursive-include src/lightning/cli/*-template *", # Add templates as build-in
|
"recursive-include src/lightning/cli/*-template *", # Add templates as build-in
|
||||||
|
"include src/lightning/app/components/serve/catimage.png" + os.linesep,
|
||||||
# fixme: this is strange, this shall work with setup find package - include
|
# fixme: this is strange, this shall work with setup find package - include
|
||||||
"prune src/lightning_app",
|
"prune src/lightning_app",
|
||||||
"prune src/lightning_lite",
|
"prune src/lightning_lite",
|
||||||
|
|
|
@ -50,6 +50,7 @@ def _adjust_manifest(**__: Any) -> None:
|
||||||
"recursive-exclude src *.md" + os.linesep,
|
"recursive-exclude src *.md" + os.linesep,
|
||||||
"recursive-exclude requirements *.txt" + os.linesep,
|
"recursive-exclude requirements *.txt" + os.linesep,
|
||||||
"recursive-include src/lightning_app *.md" + os.linesep,
|
"recursive-include src/lightning_app *.md" + os.linesep,
|
||||||
|
"include src/lightning_app/components/serve/catimage.png" + os.linesep,
|
||||||
"recursive-include requirements/app *.txt" + os.linesep,
|
"recursive-include requirements/app *.txt" + os.linesep,
|
||||||
"recursive-include src/lightning_app/cli/*-template *" + os.linesep, # Add templates
|
"recursive-include src/lightning_app/cli/*-template *" + os.linesep, # Add templates
|
||||||
]
|
]
|
||||||
|
|
|
@ -9,7 +9,7 @@ from lightning_app.components.multi_node import (
|
||||||
from lightning_app.components.python.popen import PopenPythonScript
|
from lightning_app.components.python.popen import PopenPythonScript
|
||||||
from lightning_app.components.python.tracer import Code, TracerPythonScript
|
from lightning_app.components.python.tracer import Code, TracerPythonScript
|
||||||
from lightning_app.components.serve.gradio import ServeGradio
|
from lightning_app.components.serve.gradio import ServeGradio
|
||||||
from lightning_app.components.serve.python_server import PythonServer
|
from lightning_app.components.serve.python_server import Image, Number, PythonServer
|
||||||
from lightning_app.components.serve.serve import ModelInferenceAPI
|
from lightning_app.components.serve.serve import ModelInferenceAPI
|
||||||
from lightning_app.components.serve.streamlit import ServeStreamlit
|
from lightning_app.components.serve.streamlit import ServeStreamlit
|
||||||
from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner
|
from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner
|
||||||
|
@ -24,6 +24,8 @@ __all__ = [
|
||||||
"ServeStreamlit",
|
"ServeStreamlit",
|
||||||
"ModelInferenceAPI",
|
"ModelInferenceAPI",
|
||||||
"PythonServer",
|
"PythonServer",
|
||||||
|
"Image",
|
||||||
|
"Number",
|
||||||
"MultiNode",
|
"MultiNode",
|
||||||
"LiteMultiNode",
|
"LiteMultiNode",
|
||||||
"LightningTrainingComponent",
|
"LightningTrainingComponent",
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from lightning_app.components.serve.gradio import ServeGradio
|
from lightning_app.components.serve.gradio import ServeGradio
|
||||||
from lightning_app.components.serve.python_server import PythonServer
|
from lightning_app.components.serve.python_server import Image, Number, PythonServer
|
||||||
from lightning_app.components.serve.streamlit import ServeStreamlit
|
from lightning_app.components.serve.streamlit import ServeStreamlit
|
||||||
|
|
||||||
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer"]
|
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
|
@ -1,5 +1,7 @@
|
||||||
import abc
|
import abc
|
||||||
from typing import Any, Dict
|
import base64
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
@ -12,6 +14,12 @@ from lightning_app.utilities.app_helpers import Logger
|
||||||
logger = Logger(__name__)
|
logger = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_base64(image_path):
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return encoded_string.decode("UTF-8")
|
||||||
|
|
||||||
|
|
||||||
class _DefaultInputData(BaseModel):
|
class _DefaultInputData(BaseModel):
|
||||||
payload: str
|
payload: str
|
||||||
|
|
||||||
|
@ -20,6 +28,25 @@ class _DefaultOutputData(BaseModel):
|
||||||
prediction: str
|
prediction: str
|
||||||
|
|
||||||
|
|
||||||
|
class Image(BaseModel):
|
||||||
|
image: Optional[str]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_sample_data() -> Dict[Any, Any]:
|
||||||
|
imagepath = Path(__file__).absolute().parent / "catimage.png"
|
||||||
|
with open(imagepath, "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return {"image": encoded_string.decode("UTF-8")}
|
||||||
|
|
||||||
|
|
||||||
|
class Number(BaseModel):
|
||||||
|
prediction: Optional[int]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_sample_data() -> Dict[Any, Any]:
|
||||||
|
return {"prediction": 463}
|
||||||
|
|
||||||
|
|
||||||
class PythonServer(LightningWork, abc.ABC):
|
class PythonServer(LightningWork, abc.ABC):
|
||||||
def __init__( # type: ignore
|
def __init__( # type: ignore
|
||||||
self,
|
self,
|
||||||
|
@ -110,6 +137,9 @@ class PythonServer(LightningWork, abc.ABC):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
|
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
|
||||||
|
if hasattr(datatype, "_get_sample_data"):
|
||||||
|
return datatype._get_sample_data()
|
||||||
|
|
||||||
datatype_props = datatype.schema()["properties"]
|
datatype_props = datatype.schema()["properties"]
|
||||||
out: Dict[str, Any] = {}
|
out: Dict[str, Any] = {}
|
||||||
for k, v in datatype_props.items():
|
for k, v in datatype_props.items():
|
||||||
|
@ -141,7 +171,7 @@ class PythonServer(LightningWork, abc.ABC):
|
||||||
url = self._future_url if self._future_url else self.url
|
url = self._future_url if self._future_url else self.url
|
||||||
if not url:
|
if not url:
|
||||||
# if the url is still empty, point it to localhost
|
# if the url is still empty, point it to localhost
|
||||||
url = f"http://127.0.0.1{self.port}"
|
url = f"http://127.0.0.1:{self.port}"
|
||||||
url = f"{url}/predict"
|
url = f"{url}/predict"
|
||||||
datatype_parse_error = False
|
datatype_parse_error = False
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
|
||||||
from lightning_app.components import PythonServer
|
from lightning_app.components import Image, Number, PythonServer
|
||||||
from lightning_app.utilities.network import _configure_session, find_free_network_port
|
from lightning_app.utilities.network import _configure_session, find_free_network_port
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,3 +29,17 @@ def test_python_server_component():
|
||||||
res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
|
res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
|
||||||
process.terminate()
|
process.terminate()
|
||||||
assert res.json()["prediction"] == "test"
|
assert res.json()["prediction"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_sample_data():
|
||||||
|
data = Image()._get_sample_data()
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
assert "image" in data
|
||||||
|
assert len(data["image"]) > 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_number_sample_data():
|
||||||
|
data = Number()._get_sample_data()
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
assert "prediction" in data
|
||||||
|
assert data["prediction"] == 463
|
||||||
|
|
Loading…
Reference in New Issue