diff --git a/examples/app_server/app.py b/examples/app_server/app.py new file mode 100644 index 0000000000..6cd2397f5b --- /dev/null +++ b/examples/app_server/app.py @@ -0,0 +1,42 @@ +# !pip install torchvision pydantic +import base64 +import io + +import torch +import torchvision +from PIL import Image +from pydantic import BaseModel + +import lightning as L +from lightning.app.components.serve import Image as InputImage +from lightning.app.components.serve import PythonServer + + +class PyTorchServer(PythonServer): + def setup(self): + self._model = torchvision.models.resnet18(pretrained=True) + self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self._model.to(self._device) + + def predict(self, request): + image = base64.b64decode(request.image.encode("utf-8")) + image = Image.open(io.BytesIO(image)) + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + image = transforms(image) + image = image.to(self._device) + prediction = self._model(image.unsqueeze(0)) + return {"prediction": prediction.argmax().item()} + + +class OutputData(BaseModel): + prediction: int + + +component = PyTorchServer(input_type=InputImage, output_type=OutputData, cloud_compute=L.CloudCompute("gpu")) +app = L.LightningApp(component) diff --git a/src/lightning/__setup__.py b/src/lightning/__setup__.py index 58f0ee4260..6f30e218ab 100644 --- a/src/lightning/__setup__.py +++ b/src/lightning/__setup__.py @@ -35,6 +35,7 @@ def _adjust_manifest(**kwargs: Any) -> None: "recursive-include requirements *.txt", "recursive-include src/lightning/app/ui *", "recursive-include src/lightning/cli/*-template *", # Add templates as build-in + "include src/lightning/app/components/serve/catimage.png" + os.linesep, # fixme: this is strange, this shall work with setup find package - include "prune src/lightning_app", "prune src/lightning_lite", diff --git a/src/lightning_app/__setup__.py b/src/lightning_app/__setup__.py index af5d7582cd..7a649af448 100644 --- a/src/lightning_app/__setup__.py +++ b/src/lightning_app/__setup__.py @@ -50,6 +50,7 @@ def _adjust_manifest(**__: Any) -> None: "recursive-exclude src *.md" + os.linesep, "recursive-exclude requirements *.txt" + os.linesep, "recursive-include src/lightning_app *.md" + os.linesep, + "include src/lightning_app/components/serve/catimage.png" + os.linesep, "recursive-include requirements/app *.txt" + os.linesep, "recursive-include src/lightning_app/cli/*-template *" + os.linesep, # Add templates ] diff --git a/src/lightning_app/components/__init__.py b/src/lightning_app/components/__init__.py index 2426a9042b..918d4ba911 100644 --- a/src/lightning_app/components/__init__.py +++ b/src/lightning_app/components/__init__.py @@ -9,7 +9,7 @@ from lightning_app.components.multi_node import ( from lightning_app.components.python.popen import PopenPythonScript from lightning_app.components.python.tracer import Code, TracerPythonScript from lightning_app.components.serve.gradio import ServeGradio -from lightning_app.components.serve.python_server import PythonServer +from lightning_app.components.serve.python_server import Image, Number, PythonServer from lightning_app.components.serve.serve import ModelInferenceAPI from lightning_app.components.serve.streamlit import ServeStreamlit from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner @@ -24,6 +24,8 @@ __all__ = [ "ServeStreamlit", "ModelInferenceAPI", "PythonServer", + "Image", + "Number", "MultiNode", "LiteMultiNode", "LightningTrainingComponent", diff --git a/src/lightning_app/components/serve/__init__.py b/src/lightning_app/components/serve/__init__.py index 8d25bf3a17..cb46a71bf9 100644 --- a/src/lightning_app/components/serve/__init__.py +++ b/src/lightning_app/components/serve/__init__.py @@ -1,5 +1,5 @@ from lightning_app.components.serve.gradio import ServeGradio -from lightning_app.components.serve.python_server import PythonServer +from lightning_app.components.serve.python_server import Image, Number, PythonServer from lightning_app.components.serve.streamlit import ServeStreamlit -__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer"] +__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"] diff --git a/src/lightning_app/components/serve/catimage.png b/src/lightning_app/components/serve/catimage.png new file mode 100644 index 0000000000..a76a35bdb8 Binary files /dev/null and b/src/lightning_app/components/serve/catimage.png differ diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index 99b506dae3..03b0ceb260 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -1,5 +1,7 @@ import abc -from typing import Any, Dict +import base64 +from pathlib import Path +from typing import Any, Dict, Optional import uvicorn from fastapi import FastAPI @@ -12,6 +14,12 @@ from lightning_app.utilities.app_helpers import Logger logger = Logger(__name__) +def image_to_base64(image_path): + with open(image_path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return encoded_string.decode("UTF-8") + + class _DefaultInputData(BaseModel): payload: str @@ -20,6 +28,25 @@ class _DefaultOutputData(BaseModel): prediction: str +class Image(BaseModel): + image: Optional[str] + + @staticmethod + def _get_sample_data() -> Dict[Any, Any]: + imagepath = Path(__file__).absolute().parent / "catimage.png" + with open(imagepath, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return {"image": encoded_string.decode("UTF-8")} + + +class Number(BaseModel): + prediction: Optional[int] + + @staticmethod + def _get_sample_data() -> Dict[Any, Any]: + return {"prediction": 463} + + class PythonServer(LightningWork, abc.ABC): def __init__( # type: ignore self, @@ -110,6 +137,9 @@ class PythonServer(LightningWork, abc.ABC): @staticmethod def _get_sample_dict_from_datatype(datatype: Any) -> dict: + if hasattr(datatype, "_get_sample_data"): + return datatype._get_sample_data() + datatype_props = datatype.schema()["properties"] out: Dict[str, Any] = {} for k, v in datatype_props.items(): @@ -141,7 +171,7 @@ class PythonServer(LightningWork, abc.ABC): url = self._future_url if self._future_url else self.url if not url: # if the url is still empty, point it to localhost - url = f"http://127.0.0.1{self.port}" + url = f"http://127.0.0.1:{self.port}" url = f"{url}/predict" datatype_parse_error = False try: diff --git a/tests/tests_app/components/serve/test_python_server.py b/tests/tests_app/components/serve/test_python_server.py index 7a477d98b9..313638e9ec 100644 --- a/tests/tests_app/components/serve/test_python_server.py +++ b/tests/tests_app/components/serve/test_python_server.py @@ -1,6 +1,6 @@ import multiprocessing as mp -from lightning_app.components import PythonServer +from lightning_app.components import Image, Number, PythonServer from lightning_app.utilities.network import _configure_session, find_free_network_port @@ -29,3 +29,17 @@ def test_python_server_component(): res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"}) process.terminate() assert res.json()["prediction"] == "test" + + +def test_image_sample_data(): + data = Image()._get_sample_data() + assert isinstance(data, dict) + assert "image" in data + assert len(data["image"]) > 100 + + +def test_number_sample_data(): + data = Number()._get_sample_data() + assert isinstance(data, dict) + assert "prediction" in data + assert data["prediction"] == 463