[App] PoC: Add support for Request (#16047)
This commit is contained in:
parent
005b6f2374
commit
592b12658a
|
@ -110,8 +110,8 @@ celerybeat-schedule
|
|||
|
||||
# dotenv
|
||||
.env
|
||||
.env_staging
|
||||
.env_local
|
||||
.env.staging
|
||||
.env.local
|
||||
|
||||
# virtualenv
|
||||
.venv
|
||||
|
|
|
@ -13,6 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added a progres bar while connecting to an app through the CLI ([#16035](https://github.com/Lightning-AI/lightning/pull/16035))
|
||||
|
||||
|
||||
- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
-
|
||||
|
|
|
@ -2,12 +2,14 @@ import asyncio
|
|||
import inspect
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from multiprocessing import Queue
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from lightning_utilities.core.apply_func import apply_to_collection
|
||||
|
||||
from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
|
||||
from lightning_app.utilities.app_helpers import Logger
|
||||
|
@ -19,6 +21,77 @@ def _signature_proxy_function():
|
|||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FastApiMockRequest:
|
||||
"""This class is meant to mock FastAPI Request class that isn't pickle-able.
|
||||
|
||||
If a user relies on FastAPI Request annotation, the Lightning framework
|
||||
patches the annotation before pickling and replace them right after.
|
||||
|
||||
Finally, the FastAPI request is converted back to the _FastApiMockRequest
|
||||
before being delivered to the users.
|
||||
|
||||
Example:
|
||||
|
||||
import lightning as L
|
||||
from fastapi import Request
|
||||
from lightning.app.api import Post
|
||||
|
||||
class Flow(L.LightningFlow):
|
||||
|
||||
def request(self, request: Request) -> OutputRequestModel:
|
||||
...
|
||||
|
||||
def configure_api(self):
|
||||
return [Post("/api/v1/request", self.request)]
|
||||
"""
|
||||
|
||||
_body: Optional[str] = None
|
||||
_json: Optional[str] = None
|
||||
_method: Optional[str] = None
|
||||
_headers: Optional[Dict] = None
|
||||
|
||||
@property
|
||||
def receive(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def method(self):
|
||||
raise self._method
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self._headers
|
||||
|
||||
def body(self):
|
||||
return self._body
|
||||
|
||||
def json(self):
|
||||
return self._json
|
||||
|
||||
def stream(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def form(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disconnected(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
async def _mock_fastapi_request(request: Request):
|
||||
# TODO: Add more requests parameters.
|
||||
return _FastApiMockRequest(
|
||||
_body=await request.body(),
|
||||
_json=await request.json(),
|
||||
_headers=request.headers,
|
||||
_method=request.method,
|
||||
)
|
||||
|
||||
|
||||
class _HttpMethod:
|
||||
def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs):
|
||||
"""This class is used to inject user defined methods within the App Rest API.
|
||||
|
@ -34,6 +107,7 @@ class _HttpMethod:
|
|||
self.method_annotations = method.__annotations__
|
||||
# TODO: Validate the signature contains only pydantic models.
|
||||
self.method_signature = inspect.signature(method)
|
||||
|
||||
if not self.attached_to_flow:
|
||||
self.component_name = method.__name__
|
||||
self.method = method
|
||||
|
@ -43,10 +117,16 @@ class _HttpMethod:
|
|||
self.timeout = timeout
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Enable the users to rely on FastAPI annotation typing with Request.
|
||||
# Note: Only a part of the Request functionatilities are supported.
|
||||
self._patch_fast_api_request()
|
||||
|
||||
def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
|
||||
# 1: Get the route associated with the http method.
|
||||
route = getattr(app, self.__class__.__name__.lower())
|
||||
|
||||
self._unpatch_fast_api_request()
|
||||
|
||||
# 2: Create a proxy function with the signature of the wrapped method.
|
||||
fn = deepcopy(_signature_proxy_function)
|
||||
fn.__annotations__ = self.method_annotations
|
||||
|
@ -69,6 +149,11 @@ class _HttpMethod:
|
|||
@wraps(_signature_proxy_function)
|
||||
async def _handle_request(*args, **kwargs):
|
||||
async def fn(*args, **kwargs):
|
||||
args, kwargs = apply_to_collection((args, kwargs), Request, _mock_fastapi_request)
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(v, "__await__"):
|
||||
kwargs[k] = await v
|
||||
|
||||
request_id = str(uuid4()).split("-")[0]
|
||||
logger.debug(f"Processing request {request_id} for route: {self.route}")
|
||||
request_queue.put(
|
||||
|
@ -101,6 +186,26 @@ class _HttpMethod:
|
|||
# 4: Register the user provided route to the Rest API.
|
||||
route(self.route, **self.kwargs)(_handle_request)
|
||||
|
||||
def _patch_fast_api_request(self):
|
||||
"""This function replaces signature annotation for Request with its mock."""
|
||||
for k, v in self.method_annotations.items():
|
||||
if v == Request:
|
||||
self.method_annotations[k] = _FastApiMockRequest
|
||||
|
||||
for v in self.method_signature.parameters.values():
|
||||
if v._annotation == Request:
|
||||
v._annotation = _FastApiMockRequest
|
||||
|
||||
def _unpatch_fast_api_request(self):
|
||||
"""This function replaces back signature annotation to fastapi Request."""
|
||||
for k, v in self.method_annotations.items():
|
||||
if v == _FastApiMockRequest:
|
||||
self.method_annotations[k] = Request
|
||||
|
||||
for v in self.method_signature.parameters.values():
|
||||
if v._annotation == _FastApiMockRequest:
|
||||
v._annotation = Request
|
||||
|
||||
|
||||
class Post(_HttpMethod):
|
||||
pass
|
||||
|
|
|
@ -12,7 +12,7 @@ import aiohttp
|
|||
import pytest
|
||||
import requests
|
||||
from deepdiff import DeepDiff, Delta
|
||||
from fastapi import HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
from httpx import AsyncClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -479,10 +479,13 @@ class FlowAPI(LightningFlow):
|
|||
if self.counter == 501:
|
||||
self._exit()
|
||||
|
||||
def request(self, config: InputRequestModel) -> OutputRequestModel:
|
||||
def request(self, config: InputRequestModel, request: Request) -> OutputRequestModel:
|
||||
self.counter += 1
|
||||
if config.index % 5 == 0:
|
||||
raise HTTPException(status_code=400, detail="HERE")
|
||||
assert request.body()
|
||||
assert request.json()
|
||||
assert request.headers
|
||||
return OutputRequestModel(name=config.name, counter=self.counter)
|
||||
|
||||
def configure_api(self):
|
||||
|
|
Loading…
Reference in New Issue