[App] PoC: Add support for Request (#16047)

This commit is contained in:
thomas chaton 2022-12-16 15:19:10 +01:00 committed by GitHub
parent 005b6f2374
commit 592b12658a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 116 additions and 5 deletions

4
.gitignore vendored
View File

@ -110,8 +110,8 @@ celerybeat-schedule
# dotenv
.env
.env_staging
.env_local
.env.staging
.env.local
# virtualenv
.venv

View File

@ -13,6 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a progres bar while connecting to an app through the CLI ([#16035](https://github.com/Lightning-AI/lightning/pull/16035))
- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))
### Changed
-

View File

@ -2,12 +2,14 @@ import asyncio
import inspect
import time
from copy import deepcopy
from dataclasses import dataclass
from functools import wraps
from multiprocessing import Queue
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
from lightning_app.utilities.app_helpers import Logger
@ -19,6 +21,77 @@ def _signature_proxy_function():
pass
@dataclass
class _FastApiMockRequest:
"""This class is meant to mock FastAPI Request class that isn't pickle-able.
If a user relies on FastAPI Request annotation, the Lightning framework
patches the annotation before pickling and replace them right after.
Finally, the FastAPI request is converted back to the _FastApiMockRequest
before being delivered to the users.
Example:
import lightning as L
from fastapi import Request
from lightning.app.api import Post
class Flow(L.LightningFlow):
def request(self, request: Request) -> OutputRequestModel:
...
def configure_api(self):
return [Post("/api/v1/request", self.request)]
"""
_body: Optional[str] = None
_json: Optional[str] = None
_method: Optional[str] = None
_headers: Optional[Dict] = None
@property
def receive(self):
raise NotImplementedError
@property
def method(self):
raise self._method
@property
def headers(self):
return self._headers
def body(self):
return self._body
def json(self):
return self._json
def stream(self):
raise NotImplementedError
def form(self):
raise NotImplementedError
def close(self):
raise NotImplementedError
def is_disconnected(self):
raise NotImplementedError
async def _mock_fastapi_request(request: Request):
# TODO: Add more requests parameters.
return _FastApiMockRequest(
_body=await request.body(),
_json=await request.json(),
_headers=request.headers,
_method=request.method,
)
class _HttpMethod:
def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs):
"""This class is used to inject user defined methods within the App Rest API.
@ -34,6 +107,7 @@ class _HttpMethod:
self.method_annotations = method.__annotations__
# TODO: Validate the signature contains only pydantic models.
self.method_signature = inspect.signature(method)
if not self.attached_to_flow:
self.component_name = method.__name__
self.method = method
@ -43,10 +117,16 @@ class _HttpMethod:
self.timeout = timeout
self.kwargs = kwargs
# Enable the users to rely on FastAPI annotation typing with Request.
# Note: Only a part of the Request functionatilities are supported.
self._patch_fast_api_request()
def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
# 1: Get the route associated with the http method.
route = getattr(app, self.__class__.__name__.lower())
self._unpatch_fast_api_request()
# 2: Create a proxy function with the signature of the wrapped method.
fn = deepcopy(_signature_proxy_function)
fn.__annotations__ = self.method_annotations
@ -69,6 +149,11 @@ class _HttpMethod:
@wraps(_signature_proxy_function)
async def _handle_request(*args, **kwargs):
async def fn(*args, **kwargs):
args, kwargs = apply_to_collection((args, kwargs), Request, _mock_fastapi_request)
for k, v in kwargs.items():
if hasattr(v, "__await__"):
kwargs[k] = await v
request_id = str(uuid4()).split("-")[0]
logger.debug(f"Processing request {request_id} for route: {self.route}")
request_queue.put(
@ -101,6 +186,26 @@ class _HttpMethod:
# 4: Register the user provided route to the Rest API.
route(self.route, **self.kwargs)(_handle_request)
def _patch_fast_api_request(self):
"""This function replaces signature annotation for Request with its mock."""
for k, v in self.method_annotations.items():
if v == Request:
self.method_annotations[k] = _FastApiMockRequest
for v in self.method_signature.parameters.values():
if v._annotation == Request:
v._annotation = _FastApiMockRequest
def _unpatch_fast_api_request(self):
"""This function replaces back signature annotation to fastapi Request."""
for k, v in self.method_annotations.items():
if v == _FastApiMockRequest:
self.method_annotations[k] = Request
for v in self.method_signature.parameters.values():
if v._annotation == _FastApiMockRequest:
v._annotation = Request
class Post(_HttpMethod):
pass

View File

@ -12,7 +12,7 @@ import aiohttp
import pytest
import requests
from deepdiff import DeepDiff, Delta
from fastapi import HTTPException
from fastapi import HTTPException, Request
from httpx import AsyncClient
from pydantic import BaseModel
@ -479,10 +479,13 @@ class FlowAPI(LightningFlow):
if self.counter == 501:
self._exit()
def request(self, config: InputRequestModel) -> OutputRequestModel:
def request(self, config: InputRequestModel, request: Request) -> OutputRequestModel:
self.counter += 1
if config.index % 5 == 0:
raise HTTPException(status_code=400, detail="HERE")
assert request.body()
assert request.json()
assert request.headers
return OutputRequestModel(name=config.name, counter=self.counter)
def configure_api(self):