Add support for async predict method and remove torch context in PythonServer (#16453)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
thomas chaton 2023-01-23 15:44:36 +00:00 committed by GitHub
parent 404fc0c8b7
commit 48e1c9c99c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 9 deletions

View File

@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed
-
- Add support for async predict method in PythonServer and remove torch context ([#16453](https://github.com/Lightning-AI/lightning/pull/16453))
### Deprecated

View File

@ -1,4 +1,5 @@
import abc
import asyncio
import base64
import os
import platform
@ -252,19 +253,19 @@ class PythonServer(LightningWork, abc.ABC):
return out
def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
from torch import inference_mode, no_grad
input_type: type = self.configure_input_type()
output_type: type = self.configure_output_type()
device = _get_device()
context = no_grad if device.type == "mps" else inference_mode
def predict_fn_sync(request: input_type): # type: ignore
return self.predict(request)
def predict_fn(request: input_type): # type: ignore
with context():
return self.predict(request)
async def async_predict_fn(request: input_type): # type: ignore
return await self.predict(request)
fastapi_app.post("/predict", response_model=output_type)(predict_fn)
if asyncio.iscoroutinefunction(self.predict):
fastapi_app.post("/predict", response_model=output_type)(async_predict_fn)
else:
fastapi_app.post("/predict", response_model=output_type)(predict_fn_sync)
def get_code_sample(self, url: str) -> Optional[str]:
input_type: Any = self.configure_input_type()