Use websockets in e2es (#14138)

This commit is contained in:
thomas chaton 2022-08-10 13:17:29 +02:00 committed by GitHub
parent d5f35ece72
commit 2f7daac4b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 97 additions and 47 deletions

View File

@ -136,10 +136,10 @@ def logs(app_name: str, components: List[str], follow: bool) -> None:
rich_colors = list(ANSI_COLOR_NAMES)
colors = {c: rich_colors[i + 1] for i, c in enumerate(components)}
for component_name, log_event in log_reader:
for log_event in log_reader:
date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S")
color = colors[component_name]
rich.print(f"[{color}]{component_name}[/{color}] {date} {log_event.message}")
color = colors[log_event.component_name]
rich.print(f"[{color}]{log_event.component_name}[/{color}] {date} {log_event.message}")
@_main.command()

View File

@ -1,26 +1,30 @@
import asyncio
import json
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import time
import traceback
from contextlib import contextmanager
from subprocess import Popen
from time import sleep
from typing import Any, Callable, Dict, Generator, List, Type
from typing import Any, Callable, Dict, Generator, List, Optional, Type
import requests
from lightning_cloud.openapi.rest import ApiException
from requests import Session
from rich import print
from rich.color import ANSI_COLOR_NAMES
from lightning_app import LightningApp, LightningFlow
from lightning_app.cli.lightning_cli import run_app
from lightning_app.core.constants import LIGHTNING_CLOUD_PROJECT_ID
from lightning_app.runners.multiprocess import MultiProcessRuntime
from lightning_app.testing.config import Config
from lightning_app.utilities.app_logs import _app_logs_reader
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.enum import CacheCallsKeys
from lightning_app.utilities.imports import _is_playwright_available, requires
@ -32,6 +36,9 @@ if _is_playwright_available():
from playwright.sync_api import HttpCredentials, sync_playwright
_logger = logging.getLogger(__name__)
class LightningTestApp(LightningApp):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -282,20 +289,6 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str
var scrollingElement = (document.scrollingElement || document.body);
scrollingElement.scrollTop = scrollingElement.scrollHeight;
}, 200);
if (!window._logs) {
window._logs = [];
}
if (window.logTerminals) {
Object.entries(window.logTerminals).forEach(
([key, value]) => {
window.logTerminals[key]._onLightningWritelnHandler = function (data) {
window._logs = window._logs.concat([data]);
}
}
);
}
"""
)
@ -309,8 +302,46 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str
except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError):
pass
def fetch_logs() -> str:
return admin_page.evaluate("window._logs;")
client = LightningClient()
project = _get_project(client)
identifiers = []
rich_colors = list(ANSI_COLOR_NAMES)
def fetch_logs(component_names: Optional[List[str]] = None) -> Generator:
"""This methods creates websockets connection in threads and returns the logs to the main thread."""
app_id = admin_page.url.split("/")[-1]
if not component_names:
works = client.lightningwork_service_list_lightningwork(
project_id=project.project_id,
app_id=app_id,
).lightningworks
component_names = ["flow"] + [w.name for w in works]
def on_error_callback(ws_app, *_):
print(traceback.print_exc())
ws_app.close()
colors = {c: rich_colors[i + 1] for i, c in enumerate(component_names)}
gen = _app_logs_reader(
client=client,
project_id=project.project_id,
app_id=app_id,
component_names=component_names,
follow=False,
on_error_callback=on_error_callback,
)
max_length = max(len(c.replace("root.", "")) for c in component_names)
for log_event in gen:
message = log_event.message
identifier = f"{log_event.timestamp}{log_event.message}"
if identifier not in identifiers:
date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S")
identifiers.append(identifier)
color = colors[log_event.component_name]
padding = (max_length - len(log_event.component_name)) * " "
print(f"[{color}]{log_event.component_name}{padding}[/{color}] {date} {message}")
yield message
# 5. Print your application ID
print(
@ -323,11 +354,6 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str
pass
finally:
print("##################################################")
printed_logs = []
for log in fetch_logs():
if log not in printed_logs:
printed_logs.append(log)
print(log.split("[0m")[-1])
button = admin_page.locator('[data-cy="stop"]')
try:
button.wait_for(timeout=3 * 1000)
@ -337,8 +363,6 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str
context.close()
browser.close()
client = LightningClient()
project = _get_project(client)
list_lightningapps = client.lightningapp_instance_service_list_lightningapp_instances(project.project_id)
for lightningapp in list_lightningapps.lightningapps:

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass
from datetime import datetime, timedelta
from json import JSONDecodeError
from threading import Thread
from typing import Iterator, List, Optional, Tuple
from typing import Callable, Iterator, List, Optional
import dateutil.parser
from websocket import WebSocketApp
@ -30,10 +30,17 @@ class _LogEventLabels:
class _LogEvent:
message: str
timestamp: datetime
component_name: str
labels: _LogEventLabels
def __ge__(self, other: "_LogEvent") -> bool:
return self.timestamp >= other.timestamp
def _push_logevents_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue):
def __gt__(self, other: "_LogEvent") -> bool:
return self.timestamp > other.timestamp
def _push_log_events_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue):
"""Pushes _LogEvents from websocket to read_queue.
Returns callback function used with `on_message_callback` of websocket.WebSocketApp.
@ -43,13 +50,17 @@ def _push_logevents_to_read_queue_callback(component_name: str, read_queue: queu
# We strongly trust that the contract on API will hold atm :D
event_dict = json.loads(msg)
labels = _LogEventLabels(**event_dict["labels"])
if "message" in event_dict:
message = event_dict["message"]
timestamp = dateutil.parser.isoparse(event_dict["timestamp"])
event = _LogEvent(
message=event_dict["message"],
timestamp=dateutil.parser.isoparse(event_dict["timestamp"]),
message=message,
timestamp=timestamp,
component_name=component_name,
labels=labels,
)
read_queue.put((event.timestamp, component_name, event))
read_queue.put(event)
return callback
@ -66,8 +77,13 @@ def _error_callback(ws_app: WebSocketApp, error: Exception):
def _app_logs_reader(
client: LightningClient, project_id: str, app_id: str, component_names: List[str], follow: bool
) -> Iterator[Tuple[str, _LogEvent]]:
client: LightningClient,
project_id: str,
app_id: str,
component_names: List[str],
follow: bool,
on_error_callback: Optional[Callable] = None,
) -> Iterator[_LogEvent]:
read_queue = queue.PriorityQueue()
logs_api_client = _LightningLogsSocketAPI(client.api_client)
@ -78,8 +94,8 @@ def _app_logs_reader(
project_id=project_id,
app_id=app_id,
component=component_name,
on_message_callback=_push_logevents_to_read_queue_callback(component_name, read_queue),
on_error_callback=_error_callback,
on_message_callback=_push_log_events_to_read_queue_callback(component_name, read_queue),
on_error_callback=on_error_callback or _error_callback,
)
for component_name in component_names
]
@ -92,20 +108,19 @@ def _app_logs_reader(
for th in log_threads:
th.start()
# Print logs from queue when log event is available
user_log_start = "<<< BEGIN USER_RUN_FLOW SECTION >>>"
start_timestamp = None
# Print logs from queue when log event is available
try:
while True:
_, component_name, log_event = read_queue.get(timeout=None if follow else 1.0)
log_event: _LogEvent
log_event = read_queue.get(timeout=None if follow else 1.0)
if user_log_start in log_event.message:
start_timestamp = log_event.timestamp + timedelta(seconds=0.5)
if start_timestamp and log_event.timestamp > start_timestamp:
yield component_name, log_event
yield log_event
except queue.Empty:
# Empty is raised by queue.get if timeout is reached. Follow = False case.

View File

@ -0,0 +1,11 @@
from datetime import datetime
from unittest.mock import MagicMock
from lightning_app.utilities.app_logs import _LogEvent
def test_log_event():
event_1 = _LogEvent("", datetime.now(), MagicMock(), MagicMock())
event_2 = _LogEvent("", datetime.now(), MagicMock(), MagicMock())
assert event_1 < event_2
assert event_1 <= event_2

View File

@ -26,7 +26,7 @@ def test_commands_example_cloud() -> None:
has_logs = False
while not has_logs:
for log in fetch_logs():
for log in fetch_logs(["flow"]):
if "['something', 'else']" in log:
has_logs = True
sleep(1)

View File

@ -16,7 +16,7 @@ def test_custom_work_dependencies_example_cloud() -> None:
) as (_, _, fetch_logs, _):
has_logs = False
while not has_logs:
for log in fetch_logs():
for log in fetch_logs(["flow"]):
if "Custom Work Dependency checker End" in log:
has_logs = True
sleep(1)

View File

@ -11,14 +11,14 @@ from lightning_app.testing.testing import run_app_in_cloud
def test_drive_example_cloud() -> None:
with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_drive")) as (
_,
view_page,
_,
fetch_logs,
_,
):
has_logs = False
while not has_logs:
for log in fetch_logs():
for log in fetch_logs(["flow"]):
if "Application End!" in log:
has_logs = True
sleep(1)

View File

@ -17,7 +17,7 @@ def test_idle_timeout_example_cloud() -> None:
):
has_logs = False
while not has_logs:
for log in fetch_logs():
for log in fetch_logs(["flow"]):
if "Application End" in log:
has_logs = True
sleep(1)

View File

@ -13,7 +13,7 @@ def test_payload_example_cloud() -> None:
has_logs = False
while not has_logs:
for log in fetch_logs():
for log in fetch_logs(["flow"]):
if "Application End!" in log:
has_logs = True
sleep(1)

View File

@ -45,7 +45,7 @@ def run_v0_app(fetch_logs, view_page):
wait_for(view_page, check_content, "TAB_2", "Hello from component B")
has_logs = False
while not has_logs:
for log in fetch_logs():
for log in fetch_logs(["flow"]):
if "'a': 'a', 'b': 'b'" in log:
has_logs = True
sleep(1)