From 2e72a4c8014e9488c5ba9174c506ab63d88a9980 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 28 Oct 2022 16:06:45 +0100 Subject: [PATCH] Add support for functions (#15098) --- examples/app_commands_and_api/app.py | 12 ++- src/lightning_app/CHANGELOG.md | 1 + src/lightning_app/api/http_methods.py | 76 +++++++++++-------- .../utilities/packaging/build_config.py | 6 +- tests/tests_app/core/test_lightning_api.py | 10 ++- .../test_commands_and_api.py | 5 ++ 6 files changed, 76 insertions(+), 34 deletions(-) diff --git a/examples/app_commands_and_api/app.py b/examples/app_commands_and_api/app.py index d3529e0d8c..ea00cf72a9 100644 --- a/examples/app_commands_and_api/app.py +++ b/examples/app_commands_and_api/app.py @@ -1,10 +1,15 @@ from command import CustomCommand, CustomConfig from lightning import LightningFlow -from lightning_app.api import Post +from lightning_app.api import Get, Post from lightning_app.core.app import LightningApp +async def handler(): + print("Has been called") + return "Hello World !" + + class ChildFlow(LightningFlow): def nested_command(self, name: str): """A nested command.""" @@ -39,7 +44,10 @@ class FlowCommands(LightningFlow): return commands + self.child_flow.configure_commands() def configure_api(self): - return [Post("/user/command_without_client", self.command_without_client)] + return [ + Post("/user/command_without_client", self.command_without_client), + Get("/pure_function", handler), + ] app = LightningApp(FlowCommands(), debug=True) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index c148e7f55d..4a20306ba1 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Lightning CLI Connection to be terminal session instead of global ([#15241](https://github.com/Lightning-AI/lightning/pull/15241) - Add a `JustPyFrontend` to ease UI creation with `https://github.com/justpy-org/justpy` ([#15002](https://github.com/Lightning-AI/lightning/pull/15002)) - Added a layout endpoint to the Rest API and enable to disable pulling or pushing to the state ([#15367](https://github.com/Lightning-AI/lightning/pull/15367) +- Added support for functions for `configure_api` and `configure_commands` to be executed in the Rest API process ([#15098](https://github.com/Lightning-AI/lightning/pull/15098) ### Changed diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py index 81d107b85c..ca09a9a83e 100644 --- a/src/lightning_app/api/http_methods.py +++ b/src/lightning_app/api/http_methods.py @@ -29,58 +29,74 @@ class _HttpMethod: timeout: The time in seconds taken before raising a timeout exception. """ self.route = route - self.component_name = method.__self__.name + self.attached_to_flow = hasattr(method, "__self__") self.method_name = method_name or method.__name__ self.method_annotations = method.__annotations__ # TODO: Validate the signature contains only pydantic models. self.method_signature = inspect.signature(method) + if not self.attached_to_flow: + self.component_name = method.__name__ + self.method = method + else: + self.component_name = method.__self__.name + self.timeout = timeout self.kwargs = kwargs def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None: - # 1: Create a proxy function with the signature of the wrapped method. + # 1: Get the route associated with the http method. + route = getattr(app, self.__class__.__name__.lower()) + + # 2: Create a proxy function with the signature of the wrapped method. fn = deepcopy(_signature_proxy_function) fn.__annotations__ = self.method_annotations fn.__name__ = self.method_name setattr(fn, "__signature__", self.method_signature) - # 2: Get the route associated with the http method. - route = getattr(app, self.__class__.__name__.lower()) + # Note: Handle requests differently if attached to a flow. + if not self.attached_to_flow: + # 3: Define the request handler. + @wraps(_signature_proxy_function) + async def _handle_request(*args, **kwargs): + if inspect.iscoroutinefunction(self.method): + return await self.method(*args, **kwargs) + return self.method(*args, **kwargs) - request_cls = _CommandRequest if self.route.startswith("/command/") else _APIRequest + else: + request_cls = _CommandRequest if self.route.startswith("/command/") else _APIRequest - # 3: Define the request handler. - @wraps(_signature_proxy_function) - async def _handle_request(*args, **kwargs): - async def fn(*args, **kwargs): - request_id = str(uuid4()).split("-")[0] - logger.debug(f"Processing request {request_id} for route: {self.route}") - request_queue.put( - request_cls( - name=self.component_name, - method_name=self.method_name, - args=args, - kwargs=kwargs, - id=request_id, + # 3: Define the request handler. + @wraps(_signature_proxy_function) + async def _handle_request(*args, **kwargs): + async def fn(*args, **kwargs): + request_id = str(uuid4()).split("-")[0] + logger.debug(f"Processing request {request_id} for route: {self.route}") + request_queue.put( + request_cls( + name=self.component_name, + method_name=self.method_name, + args=args, + kwargs=kwargs, + id=request_id, + ) ) - ) - t0 = time.time() - while request_id not in responses_store: - await asyncio.sleep(0.01) - if (time.time() - t0) > self.timeout: - raise Exception("The response was never received.") + t0 = time.time() + while request_id not in responses_store: + await asyncio.sleep(0.01) + if (time.time() - t0) > self.timeout: + raise Exception("The response was never received.") - logger.debug(f"Processed request {request_id} for route: {self.route}") + logger.debug(f"Processed request {request_id} for route: {self.route}") - return responses_store.pop(request_id) + return responses_store.pop(request_id) - response: _RequestResponse = await asyncio.create_task(fn(*args, **kwargs)) + response: _RequestResponse = await asyncio.create_task(fn(*args, **kwargs)) - if response.status_code != 200: - raise HTTPException(response.status_code, detail=response.content) + if response.status_code != 200: + raise HTTPException(response.status_code, detail=response.content) - return response.content + return response.content # 4: Register the user provided route to the Rest API. route(self.route, **self.kwargs)(_handle_request) diff --git a/src/lightning_app/utilities/packaging/build_config.py b/src/lightning_app/utilities/packaging/build_config.py index 6d42012be8..bc29a75882 100644 --- a/src/lightning_app/utilities/packaging/build_config.py +++ b/src/lightning_app/utilities/packaging/build_config.py @@ -25,7 +25,11 @@ def load_requirements( requirements = load_requirements(path_req) print(requirements) # ['numpy...', 'torch...', ...] """ - with open(os.path.join(path_dir, file_name)) as file: + path = os.path.join(path_dir, file_name) + if not os.path.isfile(path): + return [] + + with open(path) as file: lines = [ln.strip() for ln in file.readlines()] reqs = [] for ln in lines: diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index b53765eb80..e5494757cd 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -471,6 +471,11 @@ class OutputRequestModel(BaseModel): counter: int +async def handler(): + print("Has been called") + return "Hello World !" + + class FlowAPI(LightningFlow): def __init__(self): super().__init__() @@ -487,7 +492,7 @@ class FlowAPI(LightningFlow): return OutputRequestModel(name=config.name, counter=self.counter) def configure_api(self): - return [Post("/api/v1/request", self.request)] + return [Post("/api/v1/request", self.request), Post("/api/v1/handler", handler)] def target(): @@ -538,6 +543,9 @@ def test_configure_api(): assert len(results) == N assert all(r.get("detail", None) == ("HERE" if i % 5 == 0 else None) for i, r in enumerate(results)) + response = requests.post(f"http://localhost:{APP_SERVER_PORT}/api/v1/handler") + assert response.status_code == 200 + # Stop the Application try: response = requests.post(url, json=InputRequestModel(index=0, name="hello").dict()) diff --git a/tests/tests_app_examples/test_commands_and_api.py b/tests/tests_app_examples/test_commands_and_api.py index 25ca40d6d1..a9c5fe0892 100644 --- a/tests/tests_app_examples/test_commands_and_api.py +++ b/tests/tests_app_examples/test_commands_and_api.py @@ -45,3 +45,8 @@ def test_commands_and_api_example_cloud() -> None: if "['this', 'is', 'awesome']" in log: has_logs = True sleep(1) + + # 7: Send a request to the Rest API directly. + resp = requests.get(base_url + "/pure_function") + assert resp.status_code == 200 + assert resp.json() == "Hello World !"