From 2f7daac4b80bc13135f7e14dffcdd0bd3d50a654 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 10 Aug 2022 13:17:29 +0200 Subject: [PATCH] Use websockets in e2es (#14138) --- src/lightning_app/cli/lightning_cli.py | 6 +- src/lightning_app/testing/testing.py | 72 ++++++++++++------- src/lightning_app/utilities/app_logs.py | 41 +++++++---- tests/tests_app/utilities/test_app_logs.py | 11 +++ tests/tests_app_examples/test_commands.py | 2 +- .../test_custom_work_dependencies.py | 2 +- tests/tests_app_examples/test_drive.py | 4 +- tests/tests_app_examples/test_idle_timeout.py | 2 +- tests/tests_app_examples/test_payload.py | 2 +- tests/tests_app_examples/test_v0_app.py | 2 +- 10 files changed, 97 insertions(+), 47 deletions(-) create mode 100644 tests/tests_app/utilities/test_app_logs.py diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py index 45c80d4dcc..babe0aa2b2 100644 --- a/src/lightning_app/cli/lightning_cli.py +++ b/src/lightning_app/cli/lightning_cli.py @@ -136,10 +136,10 @@ def logs(app_name: str, components: List[str], follow: bool) -> None: rich_colors = list(ANSI_COLOR_NAMES) colors = {c: rich_colors[i + 1] for i, c in enumerate(components)} - for component_name, log_event in log_reader: + for log_event in log_reader: date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S") - color = colors[component_name] - rich.print(f"[{color}]{component_name}[/{color}] {date} {log_event.message}") + color = colors[log_event.component_name] + rich.print(f"[{color}]{log_event.component_name}[/{color}] {date} {log_event.message}") @_main.command() diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py index 74d57db38c..884c02a052 100644 --- a/src/lightning_app/testing/testing.py +++ b/src/lightning_app/testing/testing.py @@ -1,26 +1,30 @@ import asyncio import json +import logging import os import shutil import subprocess import sys import tempfile import time +import traceback from contextlib import contextmanager from subprocess import Popen from time import sleep -from typing import Any, Callable, Dict, Generator, List, Type +from typing import Any, Callable, Dict, Generator, List, Optional, Type import requests from lightning_cloud.openapi.rest import ApiException from requests import Session from rich import print +from rich.color import ANSI_COLOR_NAMES from lightning_app import LightningApp, LightningFlow from lightning_app.cli.lightning_cli import run_app from lightning_app.core.constants import LIGHTNING_CLOUD_PROJECT_ID from lightning_app.runners.multiprocess import MultiProcessRuntime from lightning_app.testing.config import Config +from lightning_app.utilities.app_logs import _app_logs_reader from lightning_app.utilities.cloud import _get_project from lightning_app.utilities.enum import CacheCallsKeys from lightning_app.utilities.imports import _is_playwright_available, requires @@ -32,6 +36,9 @@ if _is_playwright_available(): from playwright.sync_api import HttpCredentials, sync_playwright +_logger = logging.getLogger(__name__) + + class LightningTestApp(LightningApp): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -282,20 +289,6 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str var scrollingElement = (document.scrollingElement || document.body); scrollingElement.scrollTop = scrollingElement.scrollHeight; }, 200); - - if (!window._logs) { - window._logs = []; - } - - if (window.logTerminals) { - Object.entries(window.logTerminals).forEach( - ([key, value]) => { - window.logTerminals[key]._onLightningWritelnHandler = function (data) { - window._logs = window._logs.concat([data]); - } - } - ); - } """ ) @@ -309,8 +302,46 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError): pass - def fetch_logs() -> str: - return admin_page.evaluate("window._logs;") + client = LightningClient() + project = _get_project(client) + identifiers = [] + rich_colors = list(ANSI_COLOR_NAMES) + + def fetch_logs(component_names: Optional[List[str]] = None) -> Generator: + """This methods creates websockets connection in threads and returns the logs to the main thread.""" + app_id = admin_page.url.split("/")[-1] + + if not component_names: + works = client.lightningwork_service_list_lightningwork( + project_id=project.project_id, + app_id=app_id, + ).lightningworks + component_names = ["flow"] + [w.name for w in works] + + def on_error_callback(ws_app, *_): + print(traceback.print_exc()) + ws_app.close() + + colors = {c: rich_colors[i + 1] for i, c in enumerate(component_names)} + gen = _app_logs_reader( + client=client, + project_id=project.project_id, + app_id=app_id, + component_names=component_names, + follow=False, + on_error_callback=on_error_callback, + ) + max_length = max(len(c.replace("root.", "")) for c in component_names) + for log_event in gen: + message = log_event.message + identifier = f"{log_event.timestamp}{log_event.message}" + if identifier not in identifiers: + date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S") + identifiers.append(identifier) + color = colors[log_event.component_name] + padding = (max_length - len(log_event.component_name)) * " " + print(f"[{color}]{log_event.component_name}{padding}[/{color}] {date} {message}") + yield message # 5. Print your application ID print( @@ -323,11 +354,6 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str pass finally: print("##################################################") - printed_logs = [] - for log in fetch_logs(): - if log not in printed_logs: - printed_logs.append(log) - print(log.split("[0m")[-1]) button = admin_page.locator('[data-cy="stop"]') try: button.wait_for(timeout=3 * 1000) @@ -337,8 +363,6 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py", extra_args: [str context.close() browser.close() - client = LightningClient() - project = _get_project(client) list_lightningapps = client.lightningapp_instance_service_list_lightningapp_instances(project.project_id) for lightningapp in list_lightningapps.lightningapps: diff --git a/src/lightning_app/utilities/app_logs.py b/src/lightning_app/utilities/app_logs.py index 4a7af9b5c5..536fbaae05 100644 --- a/src/lightning_app/utilities/app_logs.py +++ b/src/lightning_app/utilities/app_logs.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from json import JSONDecodeError from threading import Thread -from typing import Iterator, List, Optional, Tuple +from typing import Callable, Iterator, List, Optional import dateutil.parser from websocket import WebSocketApp @@ -30,10 +30,17 @@ class _LogEventLabels: class _LogEvent: message: str timestamp: datetime + component_name: str labels: _LogEventLabels + def __ge__(self, other: "_LogEvent") -> bool: + return self.timestamp >= other.timestamp -def _push_logevents_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue): + def __gt__(self, other: "_LogEvent") -> bool: + return self.timestamp > other.timestamp + + +def _push_log_events_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue): """Pushes _LogEvents from websocket to read_queue. Returns callback function used with `on_message_callback` of websocket.WebSocketApp. @@ -43,13 +50,17 @@ def _push_logevents_to_read_queue_callback(component_name: str, read_queue: queu # We strongly trust that the contract on API will hold atm :D event_dict = json.loads(msg) labels = _LogEventLabels(**event_dict["labels"]) + if "message" in event_dict: + message = event_dict["message"] + timestamp = dateutil.parser.isoparse(event_dict["timestamp"]) event = _LogEvent( - message=event_dict["message"], - timestamp=dateutil.parser.isoparse(event_dict["timestamp"]), + message=message, + timestamp=timestamp, + component_name=component_name, labels=labels, ) - read_queue.put((event.timestamp, component_name, event)) + read_queue.put(event) return callback @@ -66,8 +77,13 @@ def _error_callback(ws_app: WebSocketApp, error: Exception): def _app_logs_reader( - client: LightningClient, project_id: str, app_id: str, component_names: List[str], follow: bool -) -> Iterator[Tuple[str, _LogEvent]]: + client: LightningClient, + project_id: str, + app_id: str, + component_names: List[str], + follow: bool, + on_error_callback: Optional[Callable] = None, +) -> Iterator[_LogEvent]: read_queue = queue.PriorityQueue() logs_api_client = _LightningLogsSocketAPI(client.api_client) @@ -78,8 +94,8 @@ def _app_logs_reader( project_id=project_id, app_id=app_id, component=component_name, - on_message_callback=_push_logevents_to_read_queue_callback(component_name, read_queue), - on_error_callback=_error_callback, + on_message_callback=_push_log_events_to_read_queue_callback(component_name, read_queue), + on_error_callback=on_error_callback or _error_callback, ) for component_name in component_names ] @@ -92,20 +108,19 @@ def _app_logs_reader( for th in log_threads: th.start() + # Print logs from queue when log event is available user_log_start = "<<< BEGIN USER_RUN_FLOW SECTION >>>" start_timestamp = None # Print logs from queue when log event is available try: while True: - _, component_name, log_event = read_queue.get(timeout=None if follow else 1.0) - log_event: _LogEvent - + log_event = read_queue.get(timeout=None if follow else 1.0) if user_log_start in log_event.message: start_timestamp = log_event.timestamp + timedelta(seconds=0.5) if start_timestamp and log_event.timestamp > start_timestamp: - yield component_name, log_event + yield log_event except queue.Empty: # Empty is raised by queue.get if timeout is reached. Follow = False case. diff --git a/tests/tests_app/utilities/test_app_logs.py b/tests/tests_app/utilities/test_app_logs.py new file mode 100644 index 0000000000..e7384dd72d --- /dev/null +++ b/tests/tests_app/utilities/test_app_logs.py @@ -0,0 +1,11 @@ +from datetime import datetime +from unittest.mock import MagicMock + +from lightning_app.utilities.app_logs import _LogEvent + + +def test_log_event(): + event_1 = _LogEvent("", datetime.now(), MagicMock(), MagicMock()) + event_2 = _LogEvent("", datetime.now(), MagicMock(), MagicMock()) + assert event_1 < event_2 + assert event_1 <= event_2 diff --git a/tests/tests_app_examples/test_commands.py b/tests/tests_app_examples/test_commands.py index 266f0305c7..236e587e23 100644 --- a/tests/tests_app_examples/test_commands.py +++ b/tests/tests_app_examples/test_commands.py @@ -26,7 +26,7 @@ def test_commands_example_cloud() -> None: has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "['something', 'else']" in log: has_logs = True sleep(1) diff --git a/tests/tests_app_examples/test_custom_work_dependencies.py b/tests/tests_app_examples/test_custom_work_dependencies.py index d7c9db5ef6..b8971e0ef2 100644 --- a/tests/tests_app_examples/test_custom_work_dependencies.py +++ b/tests/tests_app_examples/test_custom_work_dependencies.py @@ -16,7 +16,7 @@ def test_custom_work_dependencies_example_cloud() -> None: ) as (_, _, fetch_logs, _): has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "Custom Work Dependency checker End" in log: has_logs = True sleep(1) diff --git a/tests/tests_app_examples/test_drive.py b/tests/tests_app_examples/test_drive.py index 14efc34587..630e76b550 100644 --- a/tests/tests_app_examples/test_drive.py +++ b/tests/tests_app_examples/test_drive.py @@ -11,14 +11,14 @@ from lightning_app.testing.testing import run_app_in_cloud def test_drive_example_cloud() -> None: with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_drive")) as ( _, - view_page, + _, fetch_logs, _, ): has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "Application End!" in log: has_logs = True sleep(1) diff --git a/tests/tests_app_examples/test_idle_timeout.py b/tests/tests_app_examples/test_idle_timeout.py index a39ae3f693..f06181ce86 100644 --- a/tests/tests_app_examples/test_idle_timeout.py +++ b/tests/tests_app_examples/test_idle_timeout.py @@ -17,7 +17,7 @@ def test_idle_timeout_example_cloud() -> None: ): has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "Application End" in log: has_logs = True sleep(1) diff --git a/tests/tests_app_examples/test_payload.py b/tests/tests_app_examples/test_payload.py index 58fc28a4a8..b40b8ca52d 100644 --- a/tests/tests_app_examples/test_payload.py +++ b/tests/tests_app_examples/test_payload.py @@ -13,7 +13,7 @@ def test_payload_example_cloud() -> None: has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "Application End!" in log: has_logs = True sleep(1) diff --git a/tests/tests_app_examples/test_v0_app.py b/tests/tests_app_examples/test_v0_app.py index acc9e285c4..026c45a4e1 100644 --- a/tests/tests_app_examples/test_v0_app.py +++ b/tests/tests_app_examples/test_v0_app.py @@ -45,7 +45,7 @@ def run_v0_app(fetch_logs, view_page): wait_for(view_page, check_content, "TAB_2", "Hello from component B") has_logs = False while not has_logs: - for log in fetch_logs(): + for log in fetch_logs(["flow"]): if "'a': 'a', 'b': 'b'" in log: has_logs = True sleep(1)