[App] Add annotations endpoint (#16159)

This commit is contained in:
Ethan Harris 2022-12-21 12:35:13 +00:00 committed by GitHub
parent 965767199c
commit 0630444fd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 0 deletions

View File

@ -1,10 +1,12 @@
import asyncio import asyncio
import json
import os import os
import queue import queue
import sys import sys
import traceback import traceback
from copy import deepcopy from copy import deepcopy
from multiprocessing import Queue from multiprocessing import Queue
from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from threading import Event, Lock, Thread from threading import Event, Lock, Thread
from time import sleep from time import sleep
@ -68,6 +70,7 @@ lock = Lock()
app_spec: Optional[List] = None app_spec: Optional[List] = None
app_status: Optional[AppStatus] = None app_status: Optional[AppStatus] = None
app_annotations: Optional[List] = None
# In the future, this would be abstracted to support horizontal scaling. # In the future, this would be abstracted to support horizontal scaling.
responses_store = {} responses_store = {}
@ -345,6 +348,13 @@ async def get_status() -> AppStatus:
return app_status return app_status
@fastapi_service.get("/api/v1/annotations", response_class=JSONResponse)
async def get_annotations() -> Union[List, Dict]:
"""Get the annotations associated with this app."""
global app_annotations
return app_annotations or []
@fastapi_service.get("/healthz", status_code=200) @fastapi_service.get("/healthz", status_code=200)
async def healthz(response: Response): async def healthz(response: Response):
"""Health check endpoint used in the cloud FastAPI servers to check the status periodically.""" """Health check endpoint used in the cloud FastAPI servers to check the status periodically."""
@ -440,6 +450,7 @@ def start_server(
global api_app_delta_queue global api_app_delta_queue
global global_app_state_store global global_app_state_store
global app_spec global app_spec
global app_annotations
app_spec = spec app_spec = spec
api_app_delta_queue = api_delta_queue api_app_delta_queue = api_delta_queue
@ -449,6 +460,12 @@ def start_server(
global_app_state_store.add(TEST_SESSION_UUID) global_app_state_store.add(TEST_SESSION_UUID)
# Load annotations
annotations_path = Path("lightning-annotations.json").resolve()
if annotations_path.exists():
with open(annotations_path) as f:
app_annotations = json.load(f)
refresher = UIRefresher(api_publish_state_queue, api_response_queue) refresher = UIRefresher(api_publish_state_queue, api_response_queue)
refresher.setDaemon(True) refresher.setDaemon(True)
refresher.start() refresher.start()

View File

@ -5,6 +5,7 @@ import os
import sys import sys
from copy import deepcopy from copy import deepcopy
from multiprocessing import Process from multiprocessing import Process
from pathlib import Path
from time import sleep, time from time import sleep, time
from unittest import mock from unittest import mock
@ -562,3 +563,36 @@ def test_configure_api():
time_left -= 0.1 time_left -= 0.1
assert process.exitcode == 0 assert process.exitcode == 0
process.kill() process.kill()
@pytest.mark.anyio
@mock.patch("lightning_app.core.api.UIRefresher", mock.MagicMock())
async def test_get_annotations(tmpdir):
cwd = os.getcwd()
os.chdir(tmpdir)
Path("lightning-annotations.json").write_text('[{"test": 3}]')
try:
app = AppStageTestingApp(FlowA(), log_level="debug")
app._update_layout()
app.stage = AppStage.BLOCKING
change_state_queue = _MockQueue("change_state_queue")
has_started_queue = _MockQueue("has_started_queue")
api_response_queue = _MockQueue("api_response_queue")
spec = extract_metadata_from_app(app)
start_server(
None,
change_state_queue,
api_response_queue,
has_started_queue=has_started_queue,
uvicorn_run=False,
spec=spec,
)
async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
response = await client.get("/api/v1/annotations")
assert response.json() == [{"test": 3}]
finally:
# Cleanup
os.chdir(cwd)