Add support for functions (#15098)

This commit is contained in:
thomas chaton 2022-10-28 16:06:45 +01:00 committed by GitHub
parent bbf7848a5f
commit 2e72a4c801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 76 additions and 34 deletions

View File

@ -1,10 +1,15 @@
from command import CustomCommand, CustomConfig
from lightning import LightningFlow
from lightning_app.api import Post
from lightning_app.api import Get, Post
from lightning_app.core.app import LightningApp
async def handler():
print("Has been called")
return "Hello World !"
class ChildFlow(LightningFlow):
def nested_command(self, name: str):
"""A nested command."""
@ -39,7 +44,10 @@ class FlowCommands(LightningFlow):
return commands + self.child_flow.configure_commands()
def configure_api(self):
return [Post("/user/command_without_client", self.command_without_client)]
return [
Post("/user/command_without_client", self.command_without_client),
Get("/pure_function", handler),
]
app = LightningApp(FlowCommands(), debug=True)

View File

@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Lightning CLI Connection to be terminal session instead of global ([#15241](https://github.com/Lightning-AI/lightning/pull/15241)
- Add a `JustPyFrontend` to ease UI creation with `https://github.com/justpy-org/justpy` ([#15002](https://github.com/Lightning-AI/lightning/pull/15002))
- Added a layout endpoint to the Rest API and enable to disable pulling or pushing to the state ([#15367](https://github.com/Lightning-AI/lightning/pull/15367)
- Added support for functions for `configure_api` and `configure_commands` to be executed in the Rest API process ([#15098](https://github.com/Lightning-AI/lightning/pull/15098)
### Changed

View File

@ -29,58 +29,74 @@ class _HttpMethod:
timeout: The time in seconds taken before raising a timeout exception.
"""
self.route = route
self.component_name = method.__self__.name
self.attached_to_flow = hasattr(method, "__self__")
self.method_name = method_name or method.__name__
self.method_annotations = method.__annotations__
# TODO: Validate the signature contains only pydantic models.
self.method_signature = inspect.signature(method)
if not self.attached_to_flow:
self.component_name = method.__name__
self.method = method
else:
self.component_name = method.__self__.name
self.timeout = timeout
self.kwargs = kwargs
def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
# 1: Create a proxy function with the signature of the wrapped method.
# 1: Get the route associated with the http method.
route = getattr(app, self.__class__.__name__.lower())
# 2: Create a proxy function with the signature of the wrapped method.
fn = deepcopy(_signature_proxy_function)
fn.__annotations__ = self.method_annotations
fn.__name__ = self.method_name
setattr(fn, "__signature__", self.method_signature)
# 2: Get the route associated with the http method.
route = getattr(app, self.__class__.__name__.lower())
# Note: Handle requests differently if attached to a flow.
if not self.attached_to_flow:
# 3: Define the request handler.
@wraps(_signature_proxy_function)
async def _handle_request(*args, **kwargs):
if inspect.iscoroutinefunction(self.method):
return await self.method(*args, **kwargs)
return self.method(*args, **kwargs)
request_cls = _CommandRequest if self.route.startswith("/command/") else _APIRequest
else:
request_cls = _CommandRequest if self.route.startswith("/command/") else _APIRequest
# 3: Define the request handler.
@wraps(_signature_proxy_function)
async def _handle_request(*args, **kwargs):
async def fn(*args, **kwargs):
request_id = str(uuid4()).split("-")[0]
logger.debug(f"Processing request {request_id} for route: {self.route}")
request_queue.put(
request_cls(
name=self.component_name,
method_name=self.method_name,
args=args,
kwargs=kwargs,
id=request_id,
# 3: Define the request handler.
@wraps(_signature_proxy_function)
async def _handle_request(*args, **kwargs):
async def fn(*args, **kwargs):
request_id = str(uuid4()).split("-")[0]
logger.debug(f"Processing request {request_id} for route: {self.route}")
request_queue.put(
request_cls(
name=self.component_name,
method_name=self.method_name,
args=args,
kwargs=kwargs,
id=request_id,
)
)
)
t0 = time.time()
while request_id not in responses_store:
await asyncio.sleep(0.01)
if (time.time() - t0) > self.timeout:
raise Exception("The response was never received.")
t0 = time.time()
while request_id not in responses_store:
await asyncio.sleep(0.01)
if (time.time() - t0) > self.timeout:
raise Exception("The response was never received.")
logger.debug(f"Processed request {request_id} for route: {self.route}")
logger.debug(f"Processed request {request_id} for route: {self.route}")
return responses_store.pop(request_id)
return responses_store.pop(request_id)
response: _RequestResponse = await asyncio.create_task(fn(*args, **kwargs))
response: _RequestResponse = await asyncio.create_task(fn(*args, **kwargs))
if response.status_code != 200:
raise HTTPException(response.status_code, detail=response.content)
if response.status_code != 200:
raise HTTPException(response.status_code, detail=response.content)
return response.content
return response.content
# 4: Register the user provided route to the Rest API.
route(self.route, **self.kwargs)(_handle_request)

View File

@ -25,7 +25,11 @@ def load_requirements(
requirements = load_requirements(path_req)
print(requirements) # ['numpy...', 'torch...', ...]
"""
with open(os.path.join(path_dir, file_name)) as file:
path = os.path.join(path_dir, file_name)
if not os.path.isfile(path):
return []
with open(path) as file:
lines = [ln.strip() for ln in file.readlines()]
reqs = []
for ln in lines:

View File

@ -471,6 +471,11 @@ class OutputRequestModel(BaseModel):
counter: int
async def handler():
print("Has been called")
return "Hello World !"
class FlowAPI(LightningFlow):
def __init__(self):
super().__init__()
@ -487,7 +492,7 @@ class FlowAPI(LightningFlow):
return OutputRequestModel(name=config.name, counter=self.counter)
def configure_api(self):
return [Post("/api/v1/request", self.request)]
return [Post("/api/v1/request", self.request), Post("/api/v1/handler", handler)]
def target():
@ -538,6 +543,9 @@ def test_configure_api():
assert len(results) == N
assert all(r.get("detail", None) == ("HERE" if i % 5 == 0 else None) for i, r in enumerate(results))
response = requests.post(f"http://localhost:{APP_SERVER_PORT}/api/v1/handler")
assert response.status_code == 200
# Stop the Application
try:
response = requests.post(url, json=InputRequestModel(index=0, name="hello").dict())

View File

@ -45,3 +45,8 @@ def test_commands_and_api_example_cloud() -> None:
if "['this', 'is', 'awesome']" in log:
has_logs = True
sleep(1)
# 7: Send a request to the Rest API directly.
resp = requests.get(base_url + "/pure_function")
assert resp.status_code == 200
assert resp.json() == "Hello World !"