[App] Add env variables to desactivate pull and push of the App State (#15367)

This commit is contained in:
thomas chaton 2022-10-28 14:26:08 +01:00 committed by GitHub
parent 6b0b6b8903
commit b6ebc7b5f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 149 additions and 9 deletions

View File

@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for adding requirements to commands and installing them when missing when running an app command ([#15198](https://github.com/Lightning-AI/lightning/pull/15198) - Added support for adding requirements to commands and installing them when missing when running an app command ([#15198](https://github.com/Lightning-AI/lightning/pull/15198)
- Added Lightning CLI Connection to be terminal session instead of global ([#15241](https://github.com/Lightning-AI/lightning/pull/15241) - Added Lightning CLI Connection to be terminal session instead of global ([#15241](https://github.com/Lightning-AI/lightning/pull/15241)
- Add a `JustPyFrontend` to ease UI creation with `https://github.com/justpy-org/justpy` ([#15002](https://github.com/Lightning-AI/lightning/pull/15002)) - Add a `JustPyFrontend` to ease UI creation with `https://github.com/justpy-org/justpy` ([#15002](https://github.com/Lightning-AI/lightning/pull/15002))
- Added a layout endpoint to the Rest API and enable to disable pulling or pushing to the state ([#15367](https://github.com/Lightning-AI/lightning/pull/15367)
### Changed ### Changed

View File

@ -8,7 +8,7 @@ from multiprocessing import Queue
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from threading import Event, Lock, Thread from threading import Event, Lock, Thread
from time import sleep from time import sleep
from typing import Dict, List, Mapping, Optional from typing import Dict, List, Mapping, Optional, Union
import uvicorn import uvicorn
from deepdiff import DeepDiff, Delta from deepdiff import DeepDiff, Delta
@ -23,7 +23,14 @@ from websockets.exceptions import ConnectionClosed
from lightning_app.api.http_methods import HttpMethod from lightning_app.api.http_methods import HttpMethod
from lightning_app.api.request_types import DeltaRequest from lightning_app.api.request_types import DeltaRequest
from lightning_app.core.constants import CLOUD_QUEUE_TYPE, ENABLE_STATE_WEBSOCKET, FRONTEND_DIR from lightning_app.core.constants import (
CLOUD_QUEUE_TYPE,
ENABLE_PULLING_STATE_ENDPOINT,
ENABLE_PUSHING_STATE_ENDPOINT,
ENABLE_STATE_WEBSOCKET,
ENABLE_UPLOAD_ENDPOINT,
FRONTEND_DIR,
)
from lightning_app.core.queues import QueuingSystem from lightning_app.core.queues import QueuingSystem
from lightning_app.storage import Drive from lightning_app.storage import Drive
from lightning_app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore from lightning_app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore
@ -163,6 +170,7 @@ if _is_starsessions_available():
# ranks) # ranks)
@fastapi_service.get("/api/v1/state", response_class=JSONResponse) @fastapi_service.get("/api/v1/state", response_class=JSONResponse)
async def get_state( async def get_state(
response: Response,
x_lightning_type: Optional[str] = Header(None), x_lightning_type: Optional[str] = Header(None),
x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_uuid: Optional[str] = Header(None),
x_lightning_session_id: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None),
@ -172,6 +180,10 @@ async def get_state(
if x_lightning_session_id is None: if x_lightning_session_id is None:
raise Exception("Missing X-Lightning-Session-ID header") raise Exception("Missing X-Lightning-Session-ID header")
if not ENABLE_PULLING_STATE_ENDPOINT:
response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
return {"status": "failure", "reason": "This endpoint is disabled."}
with lock: with lock:
x_lightning_session_uuid = TEST_SESSION_UUID x_lightning_session_uuid = TEST_SESSION_UUID
state = global_app_state_store.get_app_state(x_lightning_session_uuid) state = global_app_state_store.get_app_state(x_lightning_session_uuid)
@ -179,15 +191,48 @@ async def get_state(
return state return state
def _get_component_by_name(component_name: str, state):
child = state
for child_name in component_name.split(".")[1:]:
try:
child = child["flows"][child_name]
except KeyError:
child = child["structures"][child_name]
if isinstance(child["vars"]["_layout"], list):
assert len(child["vars"]["_layout"]) == 1
return child["vars"]["_layout"][0]["target"]
return child["vars"]["_layout"]["target"]
@fastapi_service.get("/api/v1/layout", response_class=JSONResponse)
async def get_layout() -> Mapping:
with lock:
x_lightning_session_uuid = TEST_SESSION_UUID
state = global_app_state_store.get_app_state(x_lightning_session_uuid)
global_app_state_store.set_served_state(x_lightning_session_uuid, state)
layout = deepcopy(state["vars"]["_layout"])
for la in layout:
if la["content"].startswith("root."):
la["content"] = _get_component_by_name(la["content"], state)
return layout
@fastapi_service.get("/api/v1/spec", response_class=JSONResponse) @fastapi_service.get("/api/v1/spec", response_class=JSONResponse)
async def get_spec( async def get_spec(
response: Response,
x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_uuid: Optional[str] = Header(None),
x_lightning_session_id: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None),
) -> List: ) -> Union[List, Dict]:
if x_lightning_session_uuid is None: if x_lightning_session_uuid is None:
raise Exception("Missing X-Lightning-Session-UUID header") raise Exception("Missing X-Lightning-Session-UUID header")
if x_lightning_session_id is None: if x_lightning_session_id is None:
raise Exception("Missing X-Lightning-Session-ID header") raise Exception("Missing X-Lightning-Session-ID header")
if not ENABLE_PULLING_STATE_ENDPOINT:
response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
return {"status": "failure", "reason": "This endpoint is disabled."}
global app_spec global app_spec
return app_spec or [] return app_spec or []
@ -195,10 +240,11 @@ async def get_spec(
@fastapi_service.post("/api/v1/delta") @fastapi_service.post("/api/v1/delta")
async def post_delta( async def post_delta(
request: Request, request: Request,
response: Response,
x_lightning_type: Optional[str] = Header(None), x_lightning_type: Optional[str] = Header(None),
x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_uuid: Optional[str] = Header(None),
x_lightning_session_id: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None),
) -> None: ) -> Optional[Dict]:
"""This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to """This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to
update the state.""" update the state."""
@ -207,6 +253,10 @@ async def post_delta(
if x_lightning_session_id is None: if x_lightning_session_id is None:
raise Exception("Missing X-Lightning-Session-ID header") raise Exception("Missing X-Lightning-Session-ID header")
if not ENABLE_PUSHING_STATE_ENDPOINT:
response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
return {"status": "failure", "reason": "This endpoint is disabled."}
body: Dict = await request.json() body: Dict = await request.json()
api_app_delta_queue.put(DeltaRequest(delta=Delta(body["delta"]))) api_app_delta_queue.put(DeltaRequest(delta=Delta(body["delta"])))
@ -214,10 +264,11 @@ async def post_delta(
@fastapi_service.post("/api/v1/state") @fastapi_service.post("/api/v1/state")
async def post_state( async def post_state(
request: Request, request: Request,
response: Response,
x_lightning_type: Optional[str] = Header(None), x_lightning_type: Optional[str] = Header(None),
x_lightning_session_uuid: Optional[str] = Header(None), x_lightning_session_uuid: Optional[str] = Header(None),
x_lightning_session_id: Optional[str] = Header(None), x_lightning_session_id: Optional[str] = Header(None),
) -> None: ) -> Optional[Dict]:
if x_lightning_session_uuid is None: if x_lightning_session_uuid is None:
raise Exception("Missing X-Lightning-Session-UUID header") raise Exception("Missing X-Lightning-Session-UUID header")
if x_lightning_session_id is None: if x_lightning_session_id is None:
@ -231,6 +282,10 @@ async def post_state(
body: Dict = await request.json() body: Dict = await request.json()
x_lightning_session_uuid = TEST_SESSION_UUID x_lightning_session_uuid = TEST_SESSION_UUID
if not ENABLE_PUSHING_STATE_ENDPOINT:
response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
return {"status": "failure", "reason": "This endpoint is disabled."}
if "stage" in body: if "stage" in body:
last_state = global_app_state_store.get_served_state(x_lightning_session_uuid) last_state = global_app_state_store.get_served_state(x_lightning_session_uuid)
state = deepcopy(last_state) state = deepcopy(last_state)
@ -244,7 +299,11 @@ async def post_state(
@fastapi_service.put("/api/v1/upload_file/{filename}") @fastapi_service.put("/api/v1/upload_file/{filename}")
async def upload_file(filename: str, uploaded_file: UploadFile = File(...)): async def upload_file(response: Response, filename: str, uploaded_file: UploadFile = File(...)):
if not ENABLE_UPLOAD_ENDPOINT:
response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
return {"status": "failure", "reason": "This endpoint is disabled."}
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp:
drive = Drive( drive = Drive(
"lit://uploaded_files", "lit://uploaded_files",

View File

@ -49,7 +49,6 @@ LIGHTNING_CREDENTIAL_PATH = os.getenv("LIGHTNING_CREDENTIAL_PATH", str(Path(LIGH
DOT_IGNORE_FILENAME = ".lightningignore" DOT_IGNORE_FILENAME = ".lightningignore"
LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components" LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components"
LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps" LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "0")))
# EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE # EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE
DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50")) DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50"))
@ -63,3 +62,9 @@ ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = bool(
DEBUG: bool = lightning_cloud.env.DEBUG DEBUG: bool = lightning_cloud.env.DEBUG
DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_DEBUG", "0"))) DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_DEBUG", "0")))
ENABLE_PULLING_STATE_ENDPOINT = bool(int(os.getenv("ENABLE_PULLING_STATE_ENDPOINT", "1")))
ENABLE_PUSHING_STATE_ENDPOINT = ENABLE_PULLING_STATE_ENDPOINT and bool(
int(os.getenv("ENABLE_PUSHING_STATE_ENDPOINT", "1"))
)
ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "0")))
ENABLE_UPLOAD_ENDPOINT = bool(int(os.getenv("ENABLE_UPLOAD_ENDPOINT", "1")))

View File

@ -276,7 +276,11 @@ class LightningFlow:
@property @property
def flows(self): def flows(self):
"""Return its children LightningFlow.""" """Return its children LightningFlow."""
return {el: getattr(self, el) for el in sorted(self._flows)} flows = {el: getattr(self, el) for el in sorted(self._flows)}
for struct_name in sorted(self._structures):
for flow in getattr(self, struct_name).flows:
flows[flow.name] = flow
return flows
def works(self, recurse: bool = True) -> List[LightningWork]: def works(self, recurse: bool = True) -> List[LightningWork]:
"""Return its :class:`~lightning_app.core.work.LightningWork`.""" """Return its :class:`~lightning_app.core.work.LightningWork`."""

View File

@ -49,6 +49,8 @@ from lightning_app.core.constants import (
DISABLE_DEPENDENCY_CACHE, DISABLE_DEPENDENCY_CACHE,
ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER, ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER,
ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER, ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER,
ENABLE_PULLING_STATE_ENDPOINT,
ENABLE_PUSHING_STATE_ENDPOINT,
) )
from lightning_app.runners.backends.cloud import CloudBackend from lightning_app.runners.backends.cloud import CloudBackend
from lightning_app.runners.runtime import Runtime from lightning_app.runners.runtime import Runtime
@ -123,6 +125,12 @@ class CloudRuntime(Runtime):
if ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER: if ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER:
v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", value="1")) v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", value="1"))
if not ENABLE_PULLING_STATE_ENDPOINT:
v1_env_vars.append(V1EnvVar(name="ENABLE_PULLING_STATE_ENDPOINT", value="0"))
if not ENABLE_PUSHING_STATE_ENDPOINT:
v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))
work_reqs: List[V1Work] = [] work_reqs: List[V1Work] = []
for flow in self.app.flows: for flow in self.app.flows:
for work in flow.works(recurse=False): for work in flow.works(recurse=False):

View File

@ -1,5 +1,6 @@
import os import os
import shutil import shutil
import threading
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from subprocess import Popen from subprocess import Popen
@ -46,6 +47,11 @@ def pytest_sessionfinish(session, exitstatus):
continue continue
child.kill() child.kill()
main_thread = threading.current_thread()
for t in threading.enumerate():
if t is not main_thread:
t.join(0)
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def cleanup(): def cleanup():

View File

@ -16,6 +16,7 @@ from fastapi import HTTPException
from httpx import AsyncClient from httpx import AsyncClient
from pydantic import BaseModel from pydantic import BaseModel
import lightning_app
from lightning_app import LightningApp, LightningFlow, LightningWork from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.api.http_methods import Post from lightning_app.api.http_methods import Post
from lightning_app.core import api from lightning_app.core import api
@ -114,16 +115,35 @@ def test_app_state_api_with_flows(runtime_cls, tmpdir):
assert app.root.var_a == -1 assert app.root.var_a == -1
class NestedFlow(LightningFlow):
def run(self):
pass
def configure_layout(self):
return {"name": "main", "content": "https://te"}
class FlowA(LightningFlow): class FlowA(LightningFlow):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.counter = 0 self.counter = 0
self.flow = NestedFlow()
self.dict = lightning_app.structures.Dict(**{"0": NestedFlow()})
self.list = lightning_app.structures.List(*[NestedFlow()])
def run(self): def run(self):
self.counter += 1 self.counter += 1
if self.counter >= 3: if self.counter >= 3:
self._exit() self._exit()
def configure_layout(self):
return [
{"name": "main_1", "content": "https://te"},
{"name": "main_2", "content": self.flow},
{"name": "main_3", "content": self.dict["0"]},
{"name": "main_4", "content": self.list[0]},
]
class AppStageTestingApp(LightningApp): class AppStageTestingApp(LightningApp):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -193,7 +213,7 @@ def test_update_publish_state_and_maybe_refresh_ui():
@pytest.mark.parametrize("x_lightning_type", ["DEFAULT", "STREAMLIT"]) @pytest.mark.parametrize("x_lightning_type", ["DEFAULT", "STREAMLIT"])
@pytest.mark.anyio @pytest.mark.anyio
async def test_start_server(x_lightning_type): async def test_start_server(x_lightning_type, monkeypatch):
"""This test relies on FastAPI TestClient and validates that the REST API properly provides: """This test relies on FastAPI TestClient and validates that the REST API properly provides:
- the state on GET /api/v1/state - the state on GET /api/v1/state
@ -205,6 +225,7 @@ async def test_start_server(x_lightning_type):
return self._queue[0] return self._queue[0]
app = AppStageTestingApp(FlowA(), debug=True) app = AppStageTestingApp(FlowA(), debug=True)
app._update_layout()
app.stage = AppStage.BLOCKING app.stage = AppStage.BLOCKING
publish_state_queue = InfiniteQueue("publish_state_queue") publish_state_queue = InfiniteQueue("publish_state_queue")
change_state_queue = MockQueue("change_state_queue") change_state_queue = MockQueue("change_state_queue")
@ -261,6 +282,14 @@ async def test_start_server(x_lightning_type):
} }
assert response.status_code == 200 assert response.status_code == 200
response = await client.get("/api/v1/layout")
assert response.json() == [
{"name": "main_1", "content": "https://te", "target": "https://te"},
{"name": "main_2", "content": "https://te"},
{"name": "main_3", "content": "https://te"},
{"name": "main_4", "content": "https://te"},
]
response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers) response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers)
assert change_state_queue._queue[1].to_dict() == { assert change_state_queue._queue[1].to_dict() == {
"values_changed": {"root['vars']['counter']": {"new_value": 1}} "values_changed": {"root['vars']['counter']": {"new_value": 1}}
@ -281,6 +310,33 @@ async def test_start_server(x_lightning_type):
} }
assert response.status_code == 200 assert response.status_code == 200
monkeypatch.setattr(api, "ENABLE_PULLING_STATE_ENDPOINT", False)
response = await client.get("/api/v1/state", headers=headers)
assert response.status_code == 405
response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers)
assert response.status_code == 200
monkeypatch.setattr(api, "ENABLE_PUSHING_STATE_ENDPOINT", False)
response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers)
assert response.status_code == 405
response = await client.post(
"/api/v1/delta",
json={
"delta": {
"values_changed": {"root['flows']['video_search']['vars']['should_process']": {"new_value": True}}
}
},
headers=headers,
)
assert change_state_queue._queue[2].to_dict() == {
"values_changed": {"root['flows']['video_search']['vars']['should_process']": {"new_value": True}}
}
assert response.status_code == 405
# used to clean the app_state_store to following test. # used to clean the app_state_store to following test.
global_app_state_store.remove("1234") global_app_state_store.remove("1234")
global_app_state_store.add("1234") global_app_state_store.add("1234")