Add support for async predict method and remove torch context in PythonServer (#16453)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
parent
404fc0c8b7
commit
48e1c9c99c
|
@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Changed
|
||||
|
||||
-
|
||||
- Add support for async predict method in PythonServer and remove torch context ([#16453](https://github.com/Lightning-AI/lightning/pull/16453))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import platform
|
||||
|
@ -252,19 +253,19 @@ class PythonServer(LightningWork, abc.ABC):
|
|||
return out
|
||||
|
||||
def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
|
||||
from torch import inference_mode, no_grad
|
||||
|
||||
input_type: type = self.configure_input_type()
|
||||
output_type: type = self.configure_output_type()
|
||||
|
||||
device = _get_device()
|
||||
context = no_grad if device.type == "mps" else inference_mode
|
||||
def predict_fn_sync(request: input_type): # type: ignore
|
||||
return self.predict(request)
|
||||
|
||||
def predict_fn(request: input_type): # type: ignore
|
||||
with context():
|
||||
return self.predict(request)
|
||||
async def async_predict_fn(request: input_type): # type: ignore
|
||||
return await self.predict(request)
|
||||
|
||||
fastapi_app.post("/predict", response_model=output_type)(predict_fn)
|
||||
if asyncio.iscoroutinefunction(self.predict):
|
||||
fastapi_app.post("/predict", response_model=output_type)(async_predict_fn)
|
||||
else:
|
||||
fastapi_app.post("/predict", response_model=output_type)(predict_fn_sync)
|
||||
|
||||
def get_code_sample(self, url: str) -> Optional[str]:
|
||||
input_type: Any = self.configure_input_type()
|
||||
|
|
Loading…
Reference in New Issue