Add support for functions (#15098)
This commit is contained in:
parent
bbf7848a5f
commit
2e72a4c801
|
@ -1,10 +1,15 @@
|
|||
from command import CustomCommand, CustomConfig
|
||||
|
||||
from lightning import LightningFlow
|
||||
from lightning_app.api import Post
|
||||
from lightning_app.api import Get, Post
|
||||
from lightning_app.core.app import LightningApp
|
||||
|
||||
|
||||
async def handler():
|
||||
print("Has been called")
|
||||
return "Hello World !"
|
||||
|
||||
|
||||
class ChildFlow(LightningFlow):
|
||||
def nested_command(self, name: str):
|
||||
"""A nested command."""
|
||||
|
@ -39,7 +44,10 @@ class FlowCommands(LightningFlow):
|
|||
return commands + self.child_flow.configure_commands()
|
||||
|
||||
def configure_api(self):
|
||||
return [Post("/user/command_without_client", self.command_without_client)]
|
||||
return [
|
||||
Post("/user/command_without_client", self.command_without_client),
|
||||
Get("/pure_function", handler),
|
||||
]
|
||||
|
||||
|
||||
app = LightningApp(FlowCommands(), debug=True)
|
||||
|
|
|
@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added Lightning CLI Connection to be terminal session instead of global ([#15241](https://github.com/Lightning-AI/lightning/pull/15241)
|
||||
- Add a `JustPyFrontend` to ease UI creation with `https://github.com/justpy-org/justpy` ([#15002](https://github.com/Lightning-AI/lightning/pull/15002))
|
||||
- Added a layout endpoint to the Rest API and enable to disable pulling or pushing to the state ([#15367](https://github.com/Lightning-AI/lightning/pull/15367)
|
||||
- Added support for functions for `configure_api` and `configure_commands` to be executed in the Rest API process ([#15098](https://github.com/Lightning-AI/lightning/pull/15098)
|
||||
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -29,58 +29,74 @@ class _HttpMethod:
|
|||
timeout: The time in seconds taken before raising a timeout exception.
|
||||
"""
|
||||
self.route = route
|
||||
self.component_name = method.__self__.name
|
||||
self.attached_to_flow = hasattr(method, "__self__")
|
||||
self.method_name = method_name or method.__name__
|
||||
self.method_annotations = method.__annotations__
|
||||
# TODO: Validate the signature contains only pydantic models.
|
||||
self.method_signature = inspect.signature(method)
|
||||
if not self.attached_to_flow:
|
||||
self.component_name = method.__name__
|
||||
self.method = method
|
||||
else:
|
||||
self.component_name = method.__self__.name
|
||||
|
||||
self.timeout = timeout
|
||||
self.kwargs = kwargs
|
||||
|
||||
def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
|
||||
# 1: Create a proxy function with the signature of the wrapped method.
|
||||
# 1: Get the route associated with the http method.
|
||||
route = getattr(app, self.__class__.__name__.lower())
|
||||
|
||||
# 2: Create a proxy function with the signature of the wrapped method.
|
||||
fn = deepcopy(_signature_proxy_function)
|
||||
fn.__annotations__ = self.method_annotations
|
||||
fn.__name__ = self.method_name
|
||||
setattr(fn, "__signature__", self.method_signature)
|
||||
|
||||
# 2: Get the route associated with the http method.
|
||||
route = getattr(app, self.__class__.__name__.lower())
|
||||
# Note: Handle requests differently if attached to a flow.
|
||||
if not self.attached_to_flow:
|
||||
# 3: Define the request handler.
|
||||
@wraps(_signature_proxy_function)
|
||||
async def _handle_request(*args, **kwargs):
|
||||
if inspect.iscoroutinefunction(self.method):
|
||||
return await self.method(*args, **kwargs)
|
||||
return self.method(*args, **kwargs)
|
||||
|
||||
request_cls = _CommandRequest if self.route.startswith("/command/") else _APIRequest
|
||||
else:
|
||||
request_cls = _CommandRequest if self.route.startswith("/command/") else _APIRequest
|
||||
|
||||
# 3: Define the request handler.
|
||||
@wraps(_signature_proxy_function)
|
||||
async def _handle_request(*args, **kwargs):
|
||||
async def fn(*args, **kwargs):
|
||||
request_id = str(uuid4()).split("-")[0]
|
||||
logger.debug(f"Processing request {request_id} for route: {self.route}")
|
||||
request_queue.put(
|
||||
request_cls(
|
||||
name=self.component_name,
|
||||
method_name=self.method_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
id=request_id,
|
||||
# 3: Define the request handler.
|
||||
@wraps(_signature_proxy_function)
|
||||
async def _handle_request(*args, **kwargs):
|
||||
async def fn(*args, **kwargs):
|
||||
request_id = str(uuid4()).split("-")[0]
|
||||
logger.debug(f"Processing request {request_id} for route: {self.route}")
|
||||
request_queue.put(
|
||||
request_cls(
|
||||
name=self.component_name,
|
||||
method_name=self.method_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
id=request_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
t0 = time.time()
|
||||
while request_id not in responses_store:
|
||||
await asyncio.sleep(0.01)
|
||||
if (time.time() - t0) > self.timeout:
|
||||
raise Exception("The response was never received.")
|
||||
t0 = time.time()
|
||||
while request_id not in responses_store:
|
||||
await asyncio.sleep(0.01)
|
||||
if (time.time() - t0) > self.timeout:
|
||||
raise Exception("The response was never received.")
|
||||
|
||||
logger.debug(f"Processed request {request_id} for route: {self.route}")
|
||||
logger.debug(f"Processed request {request_id} for route: {self.route}")
|
||||
|
||||
return responses_store.pop(request_id)
|
||||
return responses_store.pop(request_id)
|
||||
|
||||
response: _RequestResponse = await asyncio.create_task(fn(*args, **kwargs))
|
||||
response: _RequestResponse = await asyncio.create_task(fn(*args, **kwargs))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(response.status_code, detail=response.content)
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(response.status_code, detail=response.content)
|
||||
|
||||
return response.content
|
||||
return response.content
|
||||
|
||||
# 4: Register the user provided route to the Rest API.
|
||||
route(self.route, **self.kwargs)(_handle_request)
|
||||
|
|
|
@ -25,7 +25,11 @@ def load_requirements(
|
|||
requirements = load_requirements(path_req)
|
||||
print(requirements) # ['numpy...', 'torch...', ...]
|
||||
"""
|
||||
with open(os.path.join(path_dir, file_name)) as file:
|
||||
path = os.path.join(path_dir, file_name)
|
||||
if not os.path.isfile(path):
|
||||
return []
|
||||
|
||||
with open(path) as file:
|
||||
lines = [ln.strip() for ln in file.readlines()]
|
||||
reqs = []
|
||||
for ln in lines:
|
||||
|
|
|
@ -471,6 +471,11 @@ class OutputRequestModel(BaseModel):
|
|||
counter: int
|
||||
|
||||
|
||||
async def handler():
|
||||
print("Has been called")
|
||||
return "Hello World !"
|
||||
|
||||
|
||||
class FlowAPI(LightningFlow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -487,7 +492,7 @@ class FlowAPI(LightningFlow):
|
|||
return OutputRequestModel(name=config.name, counter=self.counter)
|
||||
|
||||
def configure_api(self):
|
||||
return [Post("/api/v1/request", self.request)]
|
||||
return [Post("/api/v1/request", self.request), Post("/api/v1/handler", handler)]
|
||||
|
||||
|
||||
def target():
|
||||
|
@ -538,6 +543,9 @@ def test_configure_api():
|
|||
assert len(results) == N
|
||||
assert all(r.get("detail", None) == ("HERE" if i % 5 == 0 else None) for i, r in enumerate(results))
|
||||
|
||||
response = requests.post(f"http://localhost:{APP_SERVER_PORT}/api/v1/handler")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Stop the Application
|
||||
try:
|
||||
response = requests.post(url, json=InputRequestModel(index=0, name="hello").dict())
|
||||
|
|
|
@ -45,3 +45,8 @@ def test_commands_and_api_example_cloud() -> None:
|
|||
if "['this', 'is', 'awesome']" in log:
|
||||
has_logs = True
|
||||
sleep(1)
|
||||
|
||||
# 7: Send a request to the Rest API directly.
|
||||
resp = requests.get(base_url + "/pure_function")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == "Hello World !"
|
||||
|
|
Loading…
Reference in New Issue