diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index e983f808c2..05662d1d8c 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Add support for async predict method in PythonServer and remove torch context ([#16453](https://github.com/Lightning-AI/lightning/pull/16453))
 
 ### Deprecated
 
diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py
index caae6f584c..19a088b86c 100644
--- a/src/lightning_app/components/serve/python_server.py
+++ b/src/lightning_app/components/serve/python_server.py
@@ -1,4 +1,5 @@
 import abc
+import asyncio
 import base64
 import os
 import platform
@@ -252,19 +253,19 @@ class PythonServer(LightningWork, abc.ABC):
         return out
 
     def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
-        from torch import inference_mode, no_grad
-
         input_type: type = self.configure_input_type()
         output_type: type = self.configure_output_type()
 
-        device = _get_device()
-        context = no_grad if device.type == "mps" else inference_mode
+        def predict_fn_sync(request: input_type):  # type: ignore
+            return self.predict(request)
 
-        def predict_fn(request: input_type):  # type: ignore
-            with context():
-                return self.predict(request)
+        async def async_predict_fn(request: input_type):  # type: ignore
+            return await self.predict(request)
 
-        fastapi_app.post("/predict", response_model=output_type)(predict_fn)
+        if asyncio.iscoroutinefunction(self.predict):
+            fastapi_app.post("/predict", response_model=output_type)(async_predict_fn)
+        else:
+            fastapi_app.post("/predict", response_model=output_type)(predict_fn_sync)
 
     def get_code_sample(self, url: str) -> Optional[str]:
         input_type: Any = self.configure_input_type()