diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 70bd95ae00..c148e7f55d 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for adding requirements to commands and installing them when missing when running an app command ([#15198](https://github.com/Lightning-AI/lightning/pull/15198) - Added Lightning CLI Connection to be terminal session instead of global ([#15241](https://github.com/Lightning-AI/lightning/pull/15241) - Add a `JustPyFrontend` to ease UI creation with `https://github.com/justpy-org/justpy` ([#15002](https://github.com/Lightning-AI/lightning/pull/15002)) +- Added a layout endpoint to the Rest API and enable to disable pulling or pushing to the state ([#15367](https://github.com/Lightning-AI/lightning/pull/15367) + ### Changed diff --git a/src/lightning_app/core/api.py b/src/lightning_app/core/api.py index 4c8c2659a4..0eb34145b7 100644 --- a/src/lightning_app/core/api.py +++ b/src/lightning_app/core/api.py @@ -8,7 +8,7 @@ from multiprocessing import Queue from tempfile import TemporaryDirectory from threading import Event, Lock, Thread from time import sleep -from typing import Dict, List, Mapping, Optional +from typing import Dict, List, Mapping, Optional, Union import uvicorn from deepdiff import DeepDiff, Delta @@ -23,7 +23,14 @@ from websockets.exceptions import ConnectionClosed from lightning_app.api.http_methods import HttpMethod from lightning_app.api.request_types import DeltaRequest -from lightning_app.core.constants import CLOUD_QUEUE_TYPE, ENABLE_STATE_WEBSOCKET, FRONTEND_DIR +from lightning_app.core.constants import ( + CLOUD_QUEUE_TYPE, + ENABLE_PULLING_STATE_ENDPOINT, + ENABLE_PUSHING_STATE_ENDPOINT, + ENABLE_STATE_WEBSOCKET, + ENABLE_UPLOAD_ENDPOINT, + FRONTEND_DIR, +) from lightning_app.core.queues import QueuingSystem from lightning_app.storage import Drive from lightning_app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore @@ -163,6 +170,7 @@ if _is_starsessions_available(): # ranks) @fastapi_service.get("/api/v1/state", response_class=JSONResponse) async def get_state( + response: Response, x_lightning_type: Optional[str] = Header(None), x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None), @@ -172,6 +180,10 @@ async def get_state( if x_lightning_session_id is None: raise Exception("Missing X-Lightning-Session-ID header") + if not ENABLE_PULLING_STATE_ENDPOINT: + response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED + return {"status": "failure", "reason": "This endpoint is disabled."} + with lock: x_lightning_session_uuid = TEST_SESSION_UUID state = global_app_state_store.get_app_state(x_lightning_session_uuid) @@ -179,15 +191,48 @@ async def get_state( return state +def _get_component_by_name(component_name: str, state): + child = state + for child_name in component_name.split(".")[1:]: + try: + child = child["flows"][child_name] + except KeyError: + child = child["structures"][child_name] + + if isinstance(child["vars"]["_layout"], list): + assert len(child["vars"]["_layout"]) == 1 + return child["vars"]["_layout"][0]["target"] + return child["vars"]["_layout"]["target"] + + +@fastapi_service.get("/api/v1/layout", response_class=JSONResponse) +async def get_layout() -> Mapping: + with lock: + x_lightning_session_uuid = TEST_SESSION_UUID + state = global_app_state_store.get_app_state(x_lightning_session_uuid) + global_app_state_store.set_served_state(x_lightning_session_uuid, state) + layout = deepcopy(state["vars"]["_layout"]) + for la in layout: + if la["content"].startswith("root."): + la["content"] = _get_component_by_name(la["content"], state) + return layout + + @fastapi_service.get("/api/v1/spec", response_class=JSONResponse) async def get_spec( + response: Response, x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None), -) -> List: +) -> Union[List, Dict]: if x_lightning_session_uuid is None: raise Exception("Missing X-Lightning-Session-UUID header") if x_lightning_session_id is None: raise Exception("Missing X-Lightning-Session-ID header") + + if not ENABLE_PULLING_STATE_ENDPOINT: + response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED + return {"status": "failure", "reason": "This endpoint is disabled."} + global app_spec return app_spec or [] @@ -195,10 +240,11 @@ async def get_spec( @fastapi_service.post("/api/v1/delta") async def post_delta( request: Request, + response: Response, x_lightning_type: Optional[str] = Header(None), x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None), -) -> None: +) -> Optional[Dict]: """This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to update the state.""" @@ -207,6 +253,10 @@ async def post_delta( if x_lightning_session_id is None: raise Exception("Missing X-Lightning-Session-ID header") + if not ENABLE_PUSHING_STATE_ENDPOINT: + response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED + return {"status": "failure", "reason": "This endpoint is disabled."} + body: Dict = await request.json() api_app_delta_queue.put(DeltaRequest(delta=Delta(body["delta"]))) @@ -214,10 +264,11 @@ async def post_delta( @fastapi_service.post("/api/v1/state") async def post_state( request: Request, + response: Response, x_lightning_type: Optional[str] = Header(None), x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None), -) -> None: +) -> Optional[Dict]: if x_lightning_session_uuid is None: raise Exception("Missing X-Lightning-Session-UUID header") if x_lightning_session_id is None: @@ -231,6 +282,10 @@ async def post_state( body: Dict = await request.json() x_lightning_session_uuid = TEST_SESSION_UUID + if not ENABLE_PUSHING_STATE_ENDPOINT: + response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED + return {"status": "failure", "reason": "This endpoint is disabled."} + if "stage" in body: last_state = global_app_state_store.get_served_state(x_lightning_session_uuid) state = deepcopy(last_state) @@ -244,7 +299,11 @@ async def post_state( @fastapi_service.put("/api/v1/upload_file/{filename}") -async def upload_file(filename: str, uploaded_file: UploadFile = File(...)): +async def upload_file(response: Response, filename: str, uploaded_file: UploadFile = File(...)): + if not ENABLE_UPLOAD_ENDPOINT: + response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED + return {"status": "failure", "reason": "This endpoint is disabled."} + with TemporaryDirectory() as tmp: drive = Drive( "lit://uploaded_files", diff --git a/src/lightning_app/core/constants.py b/src/lightning_app/core/constants.py index 1c5dc862ea..f36a101908 100644 --- a/src/lightning_app/core/constants.py +++ b/src/lightning_app/core/constants.py @@ -49,7 +49,6 @@ LIGHTNING_CREDENTIAL_PATH = os.getenv("LIGHTNING_CREDENTIAL_PATH", str(Path(LIGH DOT_IGNORE_FILENAME = ".lightningignore" LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components" LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps" -ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "0"))) # EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50")) @@ -63,3 +62,9 @@ ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = bool( DEBUG: bool = lightning_cloud.env.DEBUG DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_DEBUG", "0"))) +ENABLE_PULLING_STATE_ENDPOINT = bool(int(os.getenv("ENABLE_PULLING_STATE_ENDPOINT", "1"))) +ENABLE_PUSHING_STATE_ENDPOINT = ENABLE_PULLING_STATE_ENDPOINT and bool( + int(os.getenv("ENABLE_PUSHING_STATE_ENDPOINT", "1")) +) +ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "0"))) +ENABLE_UPLOAD_ENDPOINT = bool(int(os.getenv("ENABLE_UPLOAD_ENDPOINT", "1"))) diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py index 068c0e1ef1..e65d266c84 100644 --- a/src/lightning_app/core/flow.py +++ b/src/lightning_app/core/flow.py @@ -276,7 +276,11 @@ class LightningFlow: @property def flows(self): """Return its children LightningFlow.""" - return {el: getattr(self, el) for el in sorted(self._flows)} + flows = {el: getattr(self, el) for el in sorted(self._flows)} + for struct_name in sorted(self._structures): + for flow in getattr(self, struct_name).flows: + flows[flow.name] = flow + return flows def works(self, recurse: bool = True) -> List[LightningWork]: """Return its :class:`~lightning_app.core.work.LightningWork`.""" diff --git a/src/lightning_app/runners/cloud.py b/src/lightning_app/runners/cloud.py index 0816632448..44f7fab2f4 100644 --- a/src/lightning_app/runners/cloud.py +++ b/src/lightning_app/runners/cloud.py @@ -49,6 +49,8 @@ from lightning_app.core.constants import ( DISABLE_DEPENDENCY_CACHE, ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER, ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER, + ENABLE_PULLING_STATE_ENDPOINT, + ENABLE_PUSHING_STATE_ENDPOINT, ) from lightning_app.runners.backends.cloud import CloudBackend from lightning_app.runners.runtime import Runtime @@ -123,6 +125,12 @@ class CloudRuntime(Runtime): if ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER: v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", value="1")) + if not ENABLE_PULLING_STATE_ENDPOINT: + v1_env_vars.append(V1EnvVar(name="ENABLE_PULLING_STATE_ENDPOINT", value="0")) + + if not ENABLE_PUSHING_STATE_ENDPOINT: + v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0")) + work_reqs: List[V1Work] = [] for flow in self.app.flows: for work in flow.works(recurse=False): diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py index 7492d6ae9e..508196d566 100644 --- a/tests/tests_app/conftest.py +++ b/tests/tests_app/conftest.py @@ -1,5 +1,6 @@ import os import shutil +import threading from datetime import datetime from pathlib import Path from subprocess import Popen @@ -46,6 +47,11 @@ def pytest_sessionfinish(session, exitstatus): continue child.kill() + main_thread = threading.current_thread() + for t in threading.enumerate(): + if t is not main_thread: + t.join(0) + @pytest.fixture(scope="function", autouse=True) def cleanup(): diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index 66873a70ac..80d1cfd06c 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -16,6 +16,7 @@ from fastapi import HTTPException from httpx import AsyncClient from pydantic import BaseModel +import lightning_app from lightning_app import LightningApp, LightningFlow, LightningWork from lightning_app.api.http_methods import Post from lightning_app.core import api @@ -114,16 +115,35 @@ def test_app_state_api_with_flows(runtime_cls, tmpdir): assert app.root.var_a == -1 +class NestedFlow(LightningFlow): + def run(self): + pass + + def configure_layout(self): + return {"name": "main", "content": "https://te"} + + class FlowA(LightningFlow): def __init__(self): super().__init__() self.counter = 0 + self.flow = NestedFlow() + self.dict = lightning_app.structures.Dict(**{"0": NestedFlow()}) + self.list = lightning_app.structures.List(*[NestedFlow()]) def run(self): self.counter += 1 if self.counter >= 3: self._exit() + def configure_layout(self): + return [ + {"name": "main_1", "content": "https://te"}, + {"name": "main_2", "content": self.flow}, + {"name": "main_3", "content": self.dict["0"]}, + {"name": "main_4", "content": self.list[0]}, + ] + class AppStageTestingApp(LightningApp): def __init__(self, *args, **kwargs): @@ -193,7 +213,7 @@ def test_update_publish_state_and_maybe_refresh_ui(): @pytest.mark.parametrize("x_lightning_type", ["DEFAULT", "STREAMLIT"]) @pytest.mark.anyio -async def test_start_server(x_lightning_type): +async def test_start_server(x_lightning_type, monkeypatch): """This test relies on FastAPI TestClient and validates that the REST API properly provides: - the state on GET /api/v1/state @@ -205,6 +225,7 @@ async def test_start_server(x_lightning_type): return self._queue[0] app = AppStageTestingApp(FlowA(), debug=True) + app._update_layout() app.stage = AppStage.BLOCKING publish_state_queue = InfiniteQueue("publish_state_queue") change_state_queue = MockQueue("change_state_queue") @@ -261,6 +282,14 @@ async def test_start_server(x_lightning_type): } assert response.status_code == 200 + response = await client.get("/api/v1/layout") + assert response.json() == [ + {"name": "main_1", "content": "https://te", "target": "https://te"}, + {"name": "main_2", "content": "https://te"}, + {"name": "main_3", "content": "https://te"}, + {"name": "main_4", "content": "https://te"}, + ] + response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers) assert change_state_queue._queue[1].to_dict() == { "values_changed": {"root['vars']['counter']": {"new_value": 1}} @@ -281,6 +310,33 @@ async def test_start_server(x_lightning_type): } assert response.status_code == 200 + monkeypatch.setattr(api, "ENABLE_PULLING_STATE_ENDPOINT", False) + + response = await client.get("/api/v1/state", headers=headers) + assert response.status_code == 405 + + response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers) + assert response.status_code == 200 + + monkeypatch.setattr(api, "ENABLE_PUSHING_STATE_ENDPOINT", False) + + response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers) + assert response.status_code == 405 + + response = await client.post( + "/api/v1/delta", + json={ + "delta": { + "values_changed": {"root['flows']['video_search']['vars']['should_process']": {"new_value": True}} + } + }, + headers=headers, + ) + assert change_state_queue._queue[2].to_dict() == { + "values_changed": {"root['flows']['video_search']['vars']['should_process']": {"new_value": True}} + } + assert response.status_code == 405 + # used to clean the app_state_store to following test. global_app_state_store.remove("1234") global_app_state_store.add("1234")