From 4c35867b618b4c36bfac5428756b95223f7f526a Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 25 Jul 2022 19:13:46 +0200 Subject: [PATCH] [App] Introduce Commands (#13602) --- .github/workflows/ci-app_cloud_e2e_test.yml | 3 +- .gitignore | 3 + examples/app_commands/.lightning | 1 + examples/app_commands/app.py | 39 +++ examples/app_commands/command.py | 17 ++ src/lightning_app/CHANGELOG.md | 4 + src/lightning_app/cli/lightning_cli.py | 91 ++++++- src/lightning_app/components/python/tracer.py | 4 +- src/lightning_app/core/api.py | 84 +++++- src/lightning_app/core/app.py | 9 + src/lightning_app/core/constants.py | 2 +- src/lightning_app/core/flow.py | 37 ++- src/lightning_app/core/queues.py | 17 ++ src/lightning_app/runners/backends/backend.py | 4 +- src/lightning_app/runners/multiprocess.py | 3 + src/lightning_app/source_code/uploader.py | 22 +- src/lightning_app/testing/testing.py | 6 +- src/lightning_app/utilities/cli_helpers.py | 60 ++++- .../utilities/commands/__init__.py | 3 + src/lightning_app/utilities/commands/base.py | 245 ++++++++++++++++++ .../utilities/packaging/lightning_utils.py | 18 +- src/lightning_app/utilities/proxies.py | 4 + tests/tests_app/cli/test_cli.py | 8 +- .../components/python/test_python.py | 7 +- tests/tests_app/core/test_lightning_api.py | 28 +- tests/tests_app/source_code/test_uploader.py | 5 +- tests/tests_app/utilities/test_commands.py | 162 ++++++++++++ tests/tests_app/utilities/test_state.py | 4 +- tests/tests_app_examples/test_commands.py | 31 +++ 29 files changed, 858 insertions(+), 63 deletions(-) create mode 100644 examples/app_commands/.lightning create mode 100644 examples/app_commands/app.py create mode 100644 examples/app_commands/command.py create mode 100644 src/lightning_app/utilities/commands/__init__.py create mode 100644 src/lightning_app/utilities/commands/base.py create mode 100644 tests/tests_app/utilities/test_commands.py create mode 100644 tests/tests_app_examples/test_commands.py diff --git a/.github/workflows/ci-app_cloud_e2e_test.yml b/.github/workflows/ci-app_cloud_e2e_test.yml index 3abdf9c92b..cb0fbdf40a 100644 --- a/.github/workflows/ci-app_cloud_e2e_test.yml +++ b/.github/workflows/ci-app_cloud_e2e_test.yml @@ -54,6 +54,7 @@ jobs: - custom_work_dependencies - drive - payload + - commands timeout-minutes: 35 steps: - uses: actions/checkout@v2 @@ -155,7 +156,7 @@ jobs: shell: bash run: | mkdir -p ${VIDEO_LOCATION} - HEADLESS=1 python -m pytest tests/tests_app_examples/test_${{ matrix.app_name }}.py::test_${{ matrix.app_name }}_example_cloud --timeout=900 --capture=no -v --color=yes + HEADLESS=1 PACKAGE_LIGHTNING=1 python -m pytest tests/tests_app_examples/test_${{ matrix.app_name }}.py::test_${{ matrix.app_name }}_example_cloud --timeout=900 --capture=no -v --color=yes # Delete the artifacts if successful rm -r ${VIDEO_LOCATION}/${{ matrix.app_name }} diff --git a/.gitignore b/.gitignore index ad4422b1a7..7040a91297 100644 --- a/.gitignore +++ b/.gitignore @@ -109,6 +109,7 @@ celerybeat-schedule # dotenv .env +.env_stagging # virtualenv .venv @@ -160,3 +161,5 @@ tags .tags src/lightning_app/ui/* *examples/template_react_ui* +hars* +artifacts/* diff --git a/examples/app_commands/.lightning b/examples/app_commands/.lightning new file mode 100644 index 0000000000..3efc0ce628 --- /dev/null +++ b/examples/app_commands/.lightning @@ -0,0 +1 @@ +name: app-commands diff --git a/examples/app_commands/app.py b/examples/app_commands/app.py new file mode 100644 index 0000000000..99eb15c75c --- /dev/null +++ b/examples/app_commands/app.py @@ -0,0 +1,39 @@ +from command import CustomCommand, CustomConfig + +from lightning import LightningFlow +from lightning_app.core.app import LightningApp + + +class ChildFlow(LightningFlow): + def trigger_method(self, name: str): + print(f"Hello {name}") + + def configure_commands(self): + return [{"nested_trigger_command": self.trigger_method}] + + +class FlowCommands(LightningFlow): + def __init__(self): + super().__init__() + self.names = [] + self.child_flow = ChildFlow() + + def run(self): + if len(self.names): + print(self.names) + + def trigger_without_client_command(self, name: str): + self.names.append(name) + + def trigger_with_client_command(self, config: CustomConfig): + self.names.append(config.name) + + def configure_commands(self): + commands = [ + {"trigger_without_client_command": self.trigger_without_client_command}, + {"trigger_with_client_command": CustomCommand(self.trigger_with_client_command)}, + ] + return commands + self.child_flow.configure_commands() + + +app = LightningApp(FlowCommands()) diff --git a/examples/app_commands/command.py b/examples/app_commands/command.py new file mode 100644 index 0000000000..8c3070f6d7 --- /dev/null +++ b/examples/app_commands/command.py @@ -0,0 +1,17 @@ +from argparse import ArgumentParser + +from pydantic import BaseModel + +from lightning.app.utilities.commands import ClientCommand + + +class CustomConfig(BaseModel): + name: str + + +class CustomCommand(ClientCommand): + def run(self): + parser = ArgumentParser() + parser.add_argument("--name", type=str) + args = parser.parse_args() + self.invoke_handler(config=CustomConfig(name=args.name)) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index bc5bf25dc8..7d0dcb589b 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602)) + +### Changed + - Update the Lightning App docs ([#13537](https://github.com/PyTorchLightning/pytorch-lightning/pull/13537)) ### Changed diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py index 696269c712..74b2d1c492 100644 --- a/src/lightning_app/cli/lightning_cli.py +++ b/src/lightning_app/cli/lightning_cli.py @@ -1,9 +1,13 @@ import logging import os +import sys +from argparse import ArgumentParser from pathlib import Path from typing import List, Tuple, Union +from uuid import uuid4 import click +import requests from requests.exceptions import ConnectionError from lightning_app import __version__ as ver @@ -11,9 +15,13 @@ from lightning_app.cli import cmd_init, cmd_install, cmd_pl_init, cmd_react_ui_i from lightning_app.core.constants import get_lightning_cloud_url, LOCAL_LAUNCH_ADMIN_VIEW from lightning_app.runners.runtime import dispatch from lightning_app.runners.runtime_type import RuntimeType -from lightning_app.utilities.cli_helpers import _format_input_env_variables +from lightning_app.utilities.cli_helpers import ( + _format_input_env_variables, + _retrieve_application_url_and_available_commands, +) from lightning_app.utilities.install_components import register_all_external_components from lightning_app.utilities.login import Auth +from lightning_app.utilities.state import headers_for logger = logging.getLogger(__name__) @@ -26,14 +34,23 @@ def get_app_url(runtime_type: RuntimeType, *args) -> str: return "http://127.0.0.1:7501/admin" if LOCAL_LAUNCH_ADMIN_VIEW else "http://127.0.0.1:7501/view" +def main(): + if len(sys.argv) == 1: + _main() + elif sys.argv[1] in _main.commands.keys() or sys.argv[1] == "--help": + _main() + else: + app_command() + + @click.group() @click.version_option(ver) -def main(): +def _main(): register_all_external_components() pass -@main.command() +@_main.command() def login(): """Log in to your Lightning.ai account.""" auth = Auth() @@ -46,7 +63,7 @@ def login(): exit(1) -@main.command() +@_main.command() def logout(): """Log out of your Lightning.ai account.""" Auth().clear() @@ -93,7 +110,7 @@ def _run_app( click.echo("Application is ready in the cloud") -@main.group() +@_main.group() def run(): """Run your application.""" @@ -125,31 +142,83 @@ def run_app( _run_app(file, cloud, without_server, no_cache, name, blocking, open_ui, env) -@main.group(hidden=True) +def app_command(): + """Execute a function in a running application from its name.""" + from lightning_app.utilities.commands.base import _download_command + + logger.warn("Lightning Commands are a beta feature and APIs aren't stable yet.") + + debug_mode = bool(int(os.getenv("DEBUG", "0"))) + + parser = ArgumentParser() + parser.add_argument("--app_id", default=None, type=str, help="Optional argument to identify an application.") + hparams, argv = parser.parse_known_args() + + # 1: Collect the url and comments from the running application + url, commands = _retrieve_application_url_and_available_commands(hparams.app_id) + if url is None or commands is None: + raise Exception("We couldn't find any matching running app.") + + if not commands: + raise Exception("This application doesn't expose any commands yet.") + + command = argv[0] + + command_names = [c["command"] for c in commands] + if command not in command_names: + raise Exception(f"The provided command {command} isn't available in {command_names}") + + # 2: Send the command from the user + command_metadata = [c for c in commands if c["command"] == command][0] + params = command_metadata["params"] + + # 3: Execute the command + if not command_metadata["is_client_command"]: + # TODO: Improve what is supported there. + kwargs = {k.split("=")[0].replace("--", ""): k.split("=")[1] for k in argv[1:]} + for param in params: + if param not in kwargs: + raise Exception(f"The argument --{param}=X hasn't been provided.") + json = { + "command_name": command, + "command_arguments": kwargs, + "affiliation": command_metadata["affiliation"], + "id": str(uuid4()), + } + resp = requests.post(url + "/api/v1/commands", json=json, headers=headers_for({})) + assert resp.status_code == 200, resp.json() + else: + client_command, models = _download_command(command_metadata, hparams.app_id, debug_mode=debug_mode) + client_command._setup(metadata=command_metadata, models=models, app_url=url) + sys.argv = argv + client_command.run() + + +@_main.group(hidden=True) def fork(): """Fork an application.""" pass -@main.group(hidden=True) +@_main.group(hidden=True) def stop(): """Stop your application.""" pass -@main.group(hidden=True) +@_main.group(hidden=True) def delete(): """Delete an application.""" pass -@main.group(name="list", hidden=True) +@_main.group(name="list", hidden=True) def get_list(): """List your applications.""" pass -@main.group() +@_main.group() def install(): """Install Lightning apps and components.""" @@ -207,7 +276,7 @@ def install_component(name, yes, version): cmd_install.gallery_component(name, yes, version) -@main.group() +@_main.group() def init(): """Init a Lightning app and component.""" diff --git a/src/lightning_app/components/python/tracer.py b/src/lightning_app/components/python/tracer.py index ed692c7f3e..fa955646ac 100644 --- a/src/lightning_app/components/python/tracer.py +++ b/src/lightning_app/components/python/tracer.py @@ -93,8 +93,6 @@ class TracerPythonScript(LightningWork): :language: python """ super().__init__(**kwargs) - if not os.path.exists(script_path): - raise FileNotFoundError(f"The provided `script_path` {script_path}` wasn't found.") self.script_path = str(script_path) if isinstance(script_args, str): script_args = script_args.split(" ") @@ -105,6 +103,8 @@ class TracerPythonScript(LightningWork): setattr(self, name, None) def run(self, **kwargs): + if not os.path.exists(self.script_path): + raise FileNotFoundError(f"The provided `script_path` {self.script_path}` wasn't found.") kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()} init_globals = globals() init_globals.update(kwargs) diff --git a/src/lightning_app/core/api.py b/src/lightning_app/core/api.py index 024eb71238..f38c1844e2 100644 --- a/src/lightning_app/core/api.py +++ b/src/lightning_app/core/api.py @@ -3,6 +3,8 @@ import logging import os import queue import sys +import time +import traceback from copy import deepcopy from multiprocessing import Queue from threading import Event, Lock, Thread @@ -40,6 +42,10 @@ STATE_EVENT = "State changed" frontend_static_dir = os.path.join(FRONTEND_DIR, "static") api_app_delta_queue: Queue = None +api_commands_requests_queue: Queue = None +api_commands_metadata_queue: Queue = None +api_commands_responses_queue: Queue = None + template = {"ui": {}, "app": {}} templates = Jinja2Templates(directory=FRONTEND_DIR) @@ -50,6 +56,8 @@ global_app_state_store.add(TEST_SESSION_UUID) lock = Lock() app_spec: Optional[List] = None +app_commands_metadata: Optional[Dict] = None +commands_response_store = {} logger = logging.getLogger(__name__) @@ -59,16 +67,22 @@ logger = logging.getLogger(__name__) class UIRefresher(Thread): - def __init__(self, api_publish_state_queue) -> None: + def __init__(self, api_publish_state_queue, api_commands_metadata_queue, api_commands_responses_queue) -> None: super().__init__(daemon=True) self.api_publish_state_queue = api_publish_state_queue + self.api_commands_metadata_queue = api_commands_metadata_queue + self.api_commands_responses_queue = api_commands_responses_queue self._exit_event = Event() def run(self): # TODO: Create multiple threads to handle the background logic # TODO: Investigate the use of `parallel=True` - while not self._exit_event.is_set(): - self.run_once() + try: + while not self._exit_event.is_set(): + self.run_once() + except Exception as e: + logger.error(traceback.print_exc()) + raise e def run_once(self): try: @@ -78,6 +92,22 @@ class UIRefresher(Thread): except queue.Empty: pass + try: + metadata = self.api_commands_metadata_queue.get(timeout=0) + with lock: + global app_commands_metadata + app_commands_metadata = metadata + except queue.Empty: + pass + + try: + response = self.api_commands_responses_queue.get(timeout=0) + with lock: + global commands_response_store + commands_response_store[response["id"]] = response["response"] + except queue.Empty: + pass + def join(self, timeout: Optional[float] = None) -> None: self._exit_event.set() super().join(timeout) @@ -146,6 +176,43 @@ async def get_spec( return app_spec or [] +@fastapi_service.post("/api/v1/commands", response_class=JSONResponse) +async def run_remote_command( + request: Request, +) -> None: + data = await request.json() + command_name = data.get("command_name", None) + if not command_name: + raise Exception("The provided command name is empty.") + command_arguments = data.get("command_arguments", None) + if not command_arguments: + raise Exception("The provided command metadata is empty.") + affiliation = data.get("affiliation", None) + if not affiliation: + raise Exception("The provided affiliation is empty.") + + async def fn(data): + request_id = data["id"] + api_commands_requests_queue.put(data) + + t0 = time.time() + while request_id not in commands_response_store: + await asyncio.sleep(0.1) + if (time.time() - t0) > 15: + raise Exception("The response was never received.") + + return commands_response_store[request_id] + + return await asyncio.create_task(fn(data)) + + +@fastapi_service.get("/api/v1/commands", response_class=JSONResponse) +async def get_commands() -> Optional[Dict]: + global app_commands_metadata + with lock: + return app_commands_metadata + + @fastapi_service.post("/api/v1/delta") async def post_delta( request: Request, @@ -279,6 +346,9 @@ class LightningUvicornServer(uvicorn.Server): def start_server( api_publish_state_queue, api_delta_queue, + commands_requests_queue, + commands_responses_queue, + commands_metadata_queue, has_started_queue: Optional[Queue] = None, host="127.0.0.1", port=8000, @@ -288,16 +358,22 @@ def start_server( ): global api_app_delta_queue global global_app_state_store + global api_commands_requests_queue + global api_commands_responses_queue global app_spec + app_spec = spec api_app_delta_queue = api_delta_queue + api_commands_requests_queue = commands_requests_queue + api_commands_responses_queue = commands_responses_queue + api_commands_metadata_queue = commands_metadata_queue if app_state_store is not None: global_app_state_store = app_state_store global_app_state_store.add(TEST_SESSION_UUID) - refresher = UIRefresher(api_publish_state_queue) + refresher = UIRefresher(api_publish_state_queue, api_commands_metadata_queue, commands_responses_queue) refresher.setDaemon(True) refresher.start() diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 81a1a2115e..6599b53efc 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -16,6 +16,7 @@ from lightning_app.core.queues import BaseQueue, SingleProcessQueue from lightning_app.frontend import Frontend from lightning_app.storage.path import storage_root_dir from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef +from lightning_app.utilities.commands.base import _populate_commands_endpoint, _process_command_requests from lightning_app.utilities.component import _convert_paths_after_init from lightning_app.utilities.enum import AppStage from lightning_app.utilities.exceptions import CacheMissException, ExitAppException @@ -72,6 +73,9 @@ class LightningApp: # queues definition. self.delta_queue: t.Optional[BaseQueue] = None self.readiness_queue: t.Optional[BaseQueue] = None + self.commands_requests_queue: t.Optional[BaseQueue] = None + self.commands_responses_queue: t.Optional[BaseQueue] = None + self.commands_metadata_queue: t.Optional[BaseQueue] = None self.api_publish_state_queue: t.Optional[BaseQueue] = None self.api_delta_queue: t.Optional[BaseQueue] = None self.error_queue: t.Optional[BaseQueue] = None @@ -81,6 +85,7 @@ class LightningApp: self.copy_response_queues: t.Optional[t.Dict[str, BaseQueue]] = None self.caller_queues: t.Optional[t.Dict[str, BaseQueue]] = None self.work_queues: t.Optional[t.Dict[str, BaseQueue]] = None + self.commands: t.Optional[t.List] = None self.should_publish_changes_to_api = False self.component_affiliation = None @@ -345,6 +350,8 @@ class LightningApp: elif self.stage == AppStage.RESTARTING: return self._apply_restarting() + _process_command_requests(self) + try: self.check_error_queue() t0 = time() @@ -397,6 +404,8 @@ class LightningApp: self._reset_run_time_monitor() + _populate_commands_endpoint(self) + while not done: done = self.run_once() diff --git a/src/lightning_app/core/constants.py b/src/lightning_app/core/constants.py index 7644f60a2c..fd62de13cc 100644 --- a/src/lightning_app/core/constants.py +++ b/src/lightning_app/core/constants.py @@ -22,7 +22,7 @@ REDIS_QUEUES_READ_DEFAULT_TIMEOUT = 0.005 REDIS_WARNING_QUEUE_SIZE = 1000 USER_ID = os.getenv("USER_ID", "1234") FRONTEND_DIR = os.path.join(os.path.dirname(lightning_app.__file__), "ui") -PREPARE_LIGHTING = bool(int(os.getenv("PREPARE_LIGHTING", "0"))) +PACKAGE_LIGHTNING = os.getenv("PACKAGE_LIGHTNING", None) LOCAL_LAUNCH_ADMIN_VIEW = bool(int(os.getenv("LOCAL_LAUNCH_ADMIN_VIEW", "0"))) CLOUD_UPLOAD_WARNING = int(os.getenv("CLOUD_UPLOAD_WARNING", "2")) DISABLE_DEPENDENCY_CACHE = bool(int(os.getenv("DISABLE_DEPENDENCY_CACHE", "0"))) diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py index a5dcfd0a77..d1af891476 100644 --- a/src/lightning_app/core/flow.py +++ b/src/lightning_app/core/flow.py @@ -356,8 +356,7 @@ class LightningFlow: class Flow(LightningFlow): def run(self): if self.schedule("hourly"): - # run some code once every hour. - print("run this every hour") + print("run some code every hour") Arguments: cron_pattern: The cron pattern to provide. Learn more at https://crontab.guru/. @@ -509,20 +508,16 @@ class LightningFlow: # add your streamlit code here! import streamlit as st - st.button("Hello!") **Example:** Arrange the UI of my children in tabs (default UI by Lightning). .. code-block:: python class Flow(LightningFlow): - ... - def configure_layout(self): return [ dict(name="First Tab", content=self.child0), dict(name="Second Tab", content=self.child1), - # You can include direct URLs too dict(name="Lightning", content="https://lightning.ai"), ] @@ -608,3 +603,33 @@ class LightningFlow: yield value self._calls[call_hash].update({"has_finished": True}) + + def configure_commands(self): + """Configure the commands of this LightningFlow. + + Returns a list of dictionaries mapping a command name to a flow method. + + .. code-block:: python + + class Flow(LightningFlow): + def __init__(self): + super().__init__() + self.names = [] + + def configure_commands(self): + return {"my_command_name": self.my_remote_method} + + def my_remote_method(self, name): + self.names.append(name) + + Once the app is running with the following command: + + .. code-block:: bash + + lightning run app app.py + + .. code-block:: bash + + lightning my_command_name --args name=my_own_name + """ + raise NotImplementedError diff --git a/src/lightning_app/core/queues.py b/src/lightning_app/core/queues.py index 3b88d89653..efac823004 100644 --- a/src/lightning_app/core/queues.py +++ b/src/lightning_app/core/queues.py @@ -36,6 +36,9 @@ ORCHESTRATOR_RESPONSE_CONSTANT = "ORCHESTRATOR_RESPONSE" ORCHESTRATOR_COPY_REQUEST_CONSTANT = "ORCHESTRATOR_COPY_REQUEST" ORCHESTRATOR_COPY_RESPONSE_CONSTANT = "ORCHESTRATOR_COPY_RESPONSE" WORK_QUEUE_CONSTANT = "WORK_QUEUE" +COMMANDS_REQUESTS_QUEUE_CONSTANT = "COMMANDS_REQUESTS_QUEUE" +COMMANDS_RESPONSES_QUEUE_CONSTANT = "COMMANDS_RESPONSES_QUEUE" +COMMANDS_METADATA_QUEUE_CONSTANT = "COMMANDS_METADATA_QUEUE" class QueuingSystem(Enum): @@ -51,6 +54,20 @@ class QueuingSystem(Enum): else: return SingleProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT) + def get_commands_requests_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": + queue_name = f"{queue_id}_{COMMANDS_REQUESTS_QUEUE_CONSTANT}" if queue_id else COMMANDS_REQUESTS_QUEUE_CONSTANT + return self._get_queue(queue_name) + + def get_commands_responses_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": + queue_name = ( + f"{queue_id}_{COMMANDS_RESPONSES_QUEUE_CONSTANT}" if queue_id else COMMANDS_RESPONSES_QUEUE_CONSTANT + ) + return self._get_queue(queue_name) + + def get_commands_metadata_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": + queue_name = f"{queue_id}_{COMMANDS_METADATA_QUEUE_CONSTANT}" if queue_id else COMMANDS_METADATA_QUEUE_CONSTANT + return self._get_queue(queue_name) + def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": queue_name = f"{queue_id}_{READINESS_QUEUE_CONSTANT}" if queue_id else READINESS_QUEUE_CONSTANT return self._get_queue(queue_name) diff --git a/src/lightning_app/runners/backends/backend.py b/src/lightning_app/runners/backends/backend.py index 80ceb105bb..c370c7098b 100644 --- a/src/lightning_app/runners/backends/backend.py +++ b/src/lightning_app/runners/backends/backend.py @@ -82,9 +82,11 @@ class Backend(ABC): kw = dict(queue_id=self.queue_id) app.delta_queue = self.queues.get_delta_queue(**kw) app.readiness_queue = self.queues.get_readiness_queue(**kw) + app.commands_requests_queue = self.queues.get_commands_requests_queue(**kw) + app.commands_responses_queue = self.queues.get_commands_responses_queue(**kw) + app.commands_metadata_queue = self.queues.get_commands_metadata_queue(**kw) app.error_queue = self.queues.get_error_queue(**kw) app.delta_queue = self.queues.get_delta_queue(**kw) - app.readiness_queue = self.queues.get_readiness_queue(**kw) app.error_queue = self.queues.get_error_queue(**kw) app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw) app.api_delta_queue = self.queues.get_api_delta_queue(**kw) diff --git a/src/lightning_app/runners/multiprocess.py b/src/lightning_app/runners/multiprocess.py index 4c58c816c5..92ec900d89 100644 --- a/src/lightning_app/runners/multiprocess.py +++ b/src/lightning_app/runners/multiprocess.py @@ -66,6 +66,9 @@ class MultiProcessRuntime(Runtime): api_publish_state_queue=self.app.api_publish_state_queue, api_delta_queue=self.app.api_delta_queue, has_started_queue=has_started_queue, + commands_requests_queue=self.app.commands_requests_queue, + commands_responses_queue=self.app.commands_responses_queue, + commands_metadata_queue=self.app.commands_metadata_queue, spec=extract_metadata_from_app(self.app), ) server_proc = multiprocessing.Process(target=start_server, kwargs=kwargs) diff --git a/src/lightning_app/source_code/uploader.py b/src/lightning_app/source_code/uploader.py index b3a77bc633..5816c01c3f 100644 --- a/src/lightning_app/source_code/uploader.py +++ b/src/lightning_app/source_code/uploader.py @@ -39,16 +39,15 @@ class FileUploader: self.total_size = total_size self.name = name - @staticmethod - def upload_s3_data(url: str, data: bytes, retries: int, disconnect_retry_wait_seconds: int) -> str: - """Send data to s3 url. + def upload_data(self, url: str, data: bytes, retries: int, disconnect_retry_wait_seconds: int) -> str: + """Send data to url. Parameters ---------- url: str - S3 url string to send data to + url string to send data to data: bytes - Bytes of data to send to S3 + Bytes of data to send to url retries: int Amount of retries disconnect_retry_wait_seconds: int @@ -65,16 +64,19 @@ class FileUploader: retries = Retry(total=10) with requests.Session() as s: s.mount("https://", HTTPAdapter(max_retries=retries)) - response = s.put(url, data=data) - if "ETag" not in response.headers: - raise ValueError(f"Unexpected response from S3, response: {response.content}") - return response.headers["ETag"] + return self._upload_data(s, url, data) except BrokenPipeError: time.sleep(disconnect_retry_wait_seconds) disconnect_retries -= 1 raise ValueError("Unable to upload file after multiple attempts") + def _upload_data(self, s: requests.Session, url: str, data: bytes): + resp = s.put(url, data=data) + if "ETag" not in resp.headers: + raise ValueError(f"Unexpected response from {url}, response: {resp.content}") + return resp.headers["ETag"] + def upload(self) -> None: """Upload files from source dir into target path in S3.""" task_id = self.progress.add_task("upload", filename=self.name, total=self.total_size) @@ -82,7 +84,7 @@ class FileUploader: try: with open(self.source_file, "rb") as f: data = f.read() - self.upload_s3_data(self.presigned_url, data, self.retries, self.disconnect_retry_wait_seconds) + self.upload_data(self.presigned_url, data, self.retries, self.disconnect_retry_wait_seconds) self.progress.update(task_id, advance=len(data)) finally: self.progress.stop() diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py index 9e7c727756..bdf37cacf0 100644 --- a/src/lightning_app/testing/testing.py +++ b/src/lightning_app/testing/testing.py @@ -179,11 +179,7 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator: # 5. Create chromium browser, auth to lightning_app.ai and yield the admin and view pages. with sync_playwright() as p: browser = p.chromium.launch(headless=bool(int(os.getenv("HEADLESS", "0")))) - payload = { - "apiKey": Config.api_key, - "username": Config.username, - "duration": "120000", - } + payload = {"apiKey": Config.api_key, "username": Config.username, "duration": "120000"} context = browser.new_context( # Eventually this will need to be deleted http_credentials=HttpCredentials( diff --git a/src/lightning_app/utilities/cli_helpers.py b/src/lightning_app/utilities/cli_helpers.py index b573440501..fcce96ec64 100644 --- a/src/lightning_app/utilities/cli_helpers.py +++ b/src/lightning_app/utilities/cli_helpers.py @@ -1,5 +1,11 @@ import re -from typing import Dict +from typing import Dict, Optional + +import requests + +from lightning_app.core.constants import APP_SERVER_PORT +from lightning_app.utilities.cloud import _get_project +from lightning_app.utilities.network import LightningClient def _format_input_env_variables(env_list: tuple) -> Dict[str, str]: @@ -35,3 +41,55 @@ def _format_input_env_variables(env_list: tuple) -> Dict[str, str]: env_vars_dict[var_name] = value return env_vars_dict + + +def _is_url(id: Optional[str]) -> bool: + if isinstance(id, str) and (id.startswith("https://") or id.startswith("http://")): + return True + return False + + +def _retrieve_application_url_and_available_commands(app_id_or_name_or_url: Optional[str]): + """This function is used to retrieve the current url associated with an id.""" + + if _is_url(app_id_or_name_or_url): + url = app_id_or_name_or_url + assert url + resp = requests.get(url + "/api/v1/commands") + if resp.status_code != 200: + raise Exception(f"The server didn't process the request properly. Found {resp.json()}") + return url, resp.json() + + # 2: If no identifier has been provided, evaluate the local application + failed_locally = False + + if app_id_or_name_or_url is None: + try: + url = f"http://localhost:{APP_SERVER_PORT}" + resp = requests.get(f"{url}/api/v1/commands") + if resp.status_code != 200: + raise Exception(f"The server didn't process the request properly. Found {resp.json()}") + return url, resp.json() + except requests.exceptions.ConnectionError: + failed_locally = True + + # 3: If an identified was provided or the local evaluation has failed, evaluate the cloud. + if app_id_or_name_or_url or failed_locally: + client = LightningClient() + project = _get_project(client) + list_lightningapps = client.lightningapp_instance_service_list_lightningapp_instances(project.project_id) + + lightningapp_names = [lightningapp.name for lightningapp in list_lightningapps.lightningapps] + + if not app_id_or_name_or_url: + raise Exception(f"Provide an application name, id or url with --app_id=X. Found {lightningapp_names}") + + for lightningapp in list_lightningapps.lightningapps: + if lightningapp.id == app_id_or_name_or_url or lightningapp.name == app_id_or_name_or_url: + if lightningapp.status.url == "": + raise Exception("The application is starting. Try in a few moments.") + resp = requests.get(lightningapp.status.url + "/api/v1/commands") + if resp.status_code != 200: + raise Exception(f"The server didn't process the request properly. Found {resp.json()}") + return lightningapp.status.url, resp.json() + return None, None diff --git a/src/lightning_app/utilities/commands/__init__.py b/src/lightning_app/utilities/commands/__init__.py new file mode 100644 index 0000000000..2ae6aba120 --- /dev/null +++ b/src/lightning_app/utilities/commands/__init__.py @@ -0,0 +1,3 @@ +from lightning_app.utilities.commands.base import ClientCommand + +__all__ = ["ClientCommand"] diff --git a/src/lightning_app/utilities/commands/base.py b/src/lightning_app/utilities/commands/base.py new file mode 100644 index 0000000000..11661e51ca --- /dev/null +++ b/src/lightning_app/utilities/commands/base.py @@ -0,0 +1,245 @@ +import errno +import inspect +import logging +import os +import os.path as osp +import shutil +import sys +from getpass import getuser +from importlib.util import module_from_spec, spec_from_file_location +from tempfile import gettempdir +from typing import Any, Callable, Dict, List, Optional, Tuple +from uuid import uuid4 + +import requests +from pydantic import BaseModel + +from lightning_app.utilities.app_helpers import is_overridden +from lightning_app.utilities.cloud import _get_project +from lightning_app.utilities.network import LightningClient +from lightning_app.utilities.state import AppState + +_logger = logging.getLogger(__name__) + + +def makedirs(path: str): + r"""Recursive directory creation function.""" + try: + os.makedirs(osp.expanduser(osp.normpath(path))) + except OSError as e: + if e.errno != errno.EEXIST and osp.isdir(path): + raise e + + +class _ClientCommandConfig(BaseModel): + command: str + affiliation: str + params: Dict[str, str] + is_client_command: bool + cls_path: str + cls_name: str + owner: str + requirements: Optional[List[str]] + + +class ClientCommand: + def __init__(self, method: Callable, requirements: Optional[List[str]] = None) -> None: + self.method = method + flow = getattr(method, "__self__", None) + self.owner = flow.name if flow else None + self.requirements = requirements + self.metadata = None + self.models: Optional[Dict[str, BaseModel]] = None + self.app_url = None + self._state = None + + def _setup(self, metadata: Dict[str, Any], models: Dict[str, BaseModel], app_url: str) -> None: + self.metadata = metadata + self.models = models + self.app_url = app_url + + @property + def state(self): + if self._state is None: + assert self.app_url + # TODO: Resolve this hack + os.environ["LIGHTNING_APP_STATE_URL"] = "1" + self._state = AppState(host=self.app_url) + self._state._request_state() + os.environ.pop("LIGHTNING_APP_STATE_URL") + return self._state + + def run(self, **cli_kwargs) -> None: + """Overrides with the logic to execute on the client side.""" + + def invoke_handler(self, **kwargs: Any) -> Dict[str, Any]: + from lightning.app.utilities.state import headers_for + + assert kwargs.keys() == self.models.keys() + for k, v in kwargs.items(): + assert isinstance(v, self.models[k]) + json = { + "command_name": self.metadata["command"], + "command_arguments": {k: v.json() for k, v in kwargs.items()}, + "affiliation": self.metadata["affiliation"], + "id": str(uuid4()), + } + resp = requests.post(self.app_url + "/api/v1/commands", json=json, headers=headers_for({})) + assert resp.status_code == 200, resp.json() + return resp.json() + + def _to_dict(self): + return {"owner": self.owner, "requirements": self.requirements} + + def __call__(self, **kwargs: Any) -> Any: + assert self.models + input = {} + for k, v in kwargs.items(): + input[k] = self.models[k].parse_raw(v) + return self.method(**input) + + +def _download_command( + command_metadata: Dict[str, Any], + app_id: Optional[str], + debug_mode: bool = False, +) -> Tuple[ClientCommand, Dict[str, BaseModel]]: + # TODO: This is a skateboard implementation and the final version will rely on versioned + # immutable commands for security concerns + config = _ClientCommandConfig(**command_metadata) + tmpdir = osp.join(gettempdir(), f"{getuser()}_commands") + makedirs(tmpdir) + target_file = osp.join(tmpdir, f"{config.command}.py") + if app_id: + client = LightningClient() + project_id = _get_project(client).project_id + response = client.lightningapp_instance_service_list_lightningapp_instance_artifacts(project_id, app_id) + for artifact in response.artifacts: + if f"commands/{config.command}.py" == artifact.filename: + r = requests.get(artifact.url, allow_redirects=True) + with open(target_file, "wb") as f: + f.write(r.content) + else: + if not debug_mode: + shutil.copy(config.cls_path, target_file) + + cls_name = config.cls_name + spec = spec_from_file_location(config.cls_name, config.cls_path if debug_mode else target_file) + mod = module_from_spec(spec) + sys.modules[cls_name] = mod + spec.loader.exec_module(mod) + command = getattr(mod, cls_name)(method=None, requirements=config.requirements) + models = {k: getattr(mod, v) for k, v in config.params.items()} + if debug_mode: + shutil.rmtree(tmpdir) + return command, models + + +def _to_annotation(anno: str) -> str: + anno = anno.split("'")[1] + if "." in anno: + return anno.split(".")[-1] + return anno + + +def _command_to_method_and_metadata(command: ClientCommand) -> Tuple[Callable, Dict[str, Any]]: + """Extract method and its metadata from a ClientCommand.""" + params = inspect.signature(command.method).parameters + command_metadata = { + "cls_path": inspect.getfile(command.__class__), + "cls_name": command.__class__.__name__, + "params": {p.name: _to_annotation(str(p.annotation)) for p in params.values()}, + **command._to_dict(), + } + method = command.method + command.models = {} + for k, v in command_metadata["params"].items(): + if v == "_empty": + raise Exception( + f"Please, annotate your method {method} with pydantic BaseModel. Refer to the documentation." + ) + config = getattr(sys.modules[command.__module__], v, None) + if config is None: + config = getattr(sys.modules[method.__module__], v, None) + if config: + raise Exception( + f"The provided annotation for the argument {k} should in the file " + f"{inspect.getfile(command.__class__)}, not {inspect.getfile(command.method)}." + ) + if config is None or not issubclass(config, BaseModel): + raise Exception( + f"The provided annotation for the argument {k} shouldn't an instance of pydantic BaseModel." + ) + command.models[k] = config + return method, command_metadata + + +def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]: + from lightning_app.storage.path import _is_s3fs_available, filesystem, shared_storage_path + + filepath = f"commands/{command_name}.py" + remote_url = str(shared_storage_path() / "artifacts" / filepath) + fs = filesystem() + + if _is_s3fs_available(): + from s3fs import S3FileSystem + + if not isinstance(fs, S3FileSystem): + return + source_file = str(inspect.getfile(command.__class__)) + remote_url = str(shared_storage_path() / "artifacts" / filepath) + fs.put(source_file, remote_url) + return filepath + + +def _populate_commands_endpoint(app): + if not is_overridden("configure_commands", app.root): + return + + # 1: Populate commands metadata + commands = app.root.configure_commands() + commands_metadata = [] + command_names = set() + for command_mapping in commands: + for command_name, command in command_mapping.items(): + is_client_command = isinstance(command, ClientCommand) + extras = {} + if is_client_command: + _upload_command(command_name, command) + command, extras = _command_to_method_and_metadata(command) + if command_name in command_names: + raise Exception(f"The component name {command_name} has already been used. They need to be unique.") + command_names.add(command_name) + params = inspect.signature(command).parameters + commands_metadata.append( + { + "command": command_name, + "affiliation": command.__self__.name, + "params": list(params.keys()), + "is_client_command": is_client_command, + **extras, + } + ) + + # 1.2: Pass the collected commands through the queue to the Rest API. + app.commands_metadata_queue.put(commands_metadata) + app.commands = commands + + +def _process_command_requests(app): + if not is_overridden("configure_commands", app.root): + return + + # 1: Populate commands metadata + commands = app.commands + + # 2: Collect requests metadata + command_query = app.get_state_changed_from_queue(app.commands_requests_queue) + if command_query: + for command in commands: + for command_name, method in command.items(): + if command_query["command_name"] == command_name: + # 2.1: Evaluate the method associated to a specific command. + # Validation is done on the CLI side. + response = method(**command_query["command_arguments"]) + app.commands_responses_queue.put({"response": response, "id": command_query["id"]}) diff --git a/src/lightning_app/utilities/packaging/lightning_utils.py b/src/lightning_app/utilities/packaging/lightning_utils.py index ae26d39ec5..37f4ff2298 100644 --- a/src/lightning_app/utilities/packaging/lightning_utils.py +++ b/src/lightning_app/utilities/packaging/lightning_utils.py @@ -15,7 +15,7 @@ from packaging.version import Version from lightning_app import _logger, _PROJECT_ROOT, _root_logger from lightning_app.__version__ import version -from lightning_app.core.constants import PREPARE_LIGHTING +from lightning_app.core.constants import PACKAGE_LIGHTNING from lightning_app.utilities.git import check_github_repository, get_dir_name logger = logging.getLogger(__name__) @@ -96,11 +96,13 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable] # Packaging the Lightning codebase happens only inside the `lightning` repo. git_dir_name = get_dir_name() if check_github_repository() else None - if not PREPARE_LIGHTING and (not git_dir_name or (git_dir_name and not git_dir_name.startswith("lightning"))): + is_lightning = git_dir_name and git_dir_name == "lightning" + + if (PACKAGE_LIGHTNING is None and not is_lightning) or PACKAGE_LIGHTNING == "0": return - if not bool(int(os.getenv("SKIP_LIGHTING_WHEELS_BUILD", "0"))): - download_frontend(_PROJECT_ROOT) - _prepare_wheel(_PROJECT_ROOT) + + download_frontend(_PROJECT_ROOT) + _prepare_wheel(_PROJECT_ROOT) logger.info("Packaged Lightning with your application.") @@ -108,11 +110,12 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable] tar_files = [os.path.join(root, tar_name)] - # skipping this by default - if not bool(int(os.getenv("SKIP_LIGHTING_UTILITY_WHEELS_BUILD", "1"))): + # Don't skip by default + if (PACKAGE_LIGHTNING or is_lightning) and not bool(int(os.getenv("SKIP_LIGHTING_UTILITY_WHEELS_BUILD", "0"))): # building and copying launcher wheel if installed in editable mode launcher_project_path = get_dist_path_if_editable_install("lightning_launcher") if launcher_project_path: + logger.info("Packaged Lightning Launcher with your application.") _prepare_wheel(launcher_project_path) tar_name = _copy_tar(launcher_project_path, root) tar_files.append(os.path.join(root, tar_name)) @@ -120,6 +123,7 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable] # building and copying lightning-cloud wheel if installed in editable mode lightning_cloud_project_path = get_dist_path_if_editable_install("lightning_cloud") if lightning_cloud_project_path: + logger.info("Packaged Lightning Cloud with your application.") _prepare_wheel(lightning_cloud_project_path) tar_name = _copy_tar(lightning_cloud_project_path, root) tar_files.append(os.path.join(root, tar_name)) diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index ead681bff7..c33e41bb70 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -5,6 +5,7 @@ import signal import sys import threading import time +import traceback import warnings from copy import deepcopy from dataclasses import dataclass @@ -398,6 +399,9 @@ class WorkRunner: ) self.delta_queue.put(ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state)))) self.work.on_exception(e) + print("########## CAPTURED EXCEPTION ###########") + print(traceback.print_exc()) + print("########## CAPTURED EXCEPTION ###########") return # 14. Copy all artifacts to the shared storage so other Works can access them while this Work gets scaled down diff --git a/tests/tests_app/cli/test_cli.py b/tests/tests_app/cli/test_cli.py index 2626116990..39d8d6b789 100644 --- a/tests/tests_app/cli/test_cli.py +++ b/tests/tests_app/cli/test_cli.py @@ -5,7 +5,7 @@ import pytest from click.testing import CliRunner from lightning_cloud.openapi import Externalv1LightningappInstance -from lightning_app.cli.lightning_cli import get_app_url, login, logout, main, run +from lightning_app.cli.lightning_cli import _main, get_app_url, login, logout, run from lightning_app.runners.runtime_type import RuntimeType @@ -37,7 +37,7 @@ def test_start_target_url(runtime_type, extra_args, lightning_cloud_url, expecte assert get_app_url(runtime_type, *extra_args) == expected_url -@pytest.mark.parametrize("command", [main, run]) +@pytest.mark.parametrize("command", [_main, run]) def test_commands(command): runner = CliRunner() result = runner.invoke(command) @@ -46,12 +46,12 @@ def test_commands(command): def test_main_lightning_cli_help(): """Validate the Lightning CLI.""" - res = os.popen("python -m lightning_app --help").read() + res = os.popen("python -m lightning --help").read() assert "login " in res assert "logout " in res assert "run " in res - res = os.popen("python -m lightning_app run --help").read() + res = os.popen("python -m lightning run --help").read() assert "app " in res # hidden run commands should not appear in the help text diff --git a/tests/tests_app/components/python/test_python.py b/tests/tests_app/components/python/test_python.py index 283f449092..61969ef1c4 100644 --- a/tests/tests_app/components/python/test_python.py +++ b/tests/tests_app/components/python/test_python.py @@ -17,10 +17,9 @@ def test_non_existing_python_script(): run_work_isolated(python_script) assert not python_script.has_started - with pytest.raises(FileNotFoundError, match=match): - python_script = TracerPythonScript(match) - run_work_isolated(python_script) - assert not python_script.has_started + python_script = TracerPythonScript(match, raise_exception=False) + run_work_isolated(python_script) + assert python_script.has_failed def test_simple_python_script(): diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index 81ba6fe0ba..9de7c63051 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -161,10 +161,12 @@ def test_update_publish_state_and_maybe_refresh_ui(): app = AppStageTestingApp(FlowA(), debug=True) publish_state_queue = MockQueue("publish_state_queue") + commands_metadata_queue = MockQueue("commands_metadata_queue") + commands_responses_queue = MockQueue("commands_metadata_queue") publish_state_queue.put(app.state_with_changes) - thread = UIRefresher(publish_state_queue) + thread = UIRefresher(publish_state_queue, commands_metadata_queue, commands_responses_queue) thread.run_once() assert global_app_state_store.get_app_state("1234") == app.state_with_changes @@ -190,11 +192,21 @@ async def test_start_server(x_lightning_type): publish_state_queue = InfiniteQueue("publish_state_queue") change_state_queue = MockQueue("change_state_queue") has_started_queue = MockQueue("has_started_queue") + commands_requests_queue = MockQueue("commands_requests_queue") + commands_responses_queue = MockQueue("commands_responses_queue") + commands_metadata_queue = MockQueue("commands_metadata_queue") state = app.state_with_changes publish_state_queue.put(state) spec = extract_metadata_from_app(app) ui_refresher = start_server( - publish_state_queue, change_state_queue, has_started_queue=has_started_queue, uvicorn_run=False, spec=spec + publish_state_queue, + change_state_queue, + commands_requests_queue, + commands_responses_queue, + commands_metadata_queue, + has_started_queue=has_started_queue, + uvicorn_run=False, + spec=spec, ) headers = headers_for({"type": x_lightning_type}) @@ -331,10 +343,16 @@ def test_start_server_started(): api_publish_state_queue = mp.Queue() api_delta_queue = mp.Queue() has_started_queue = mp.Queue() + commands_requests_queue = mp.Queue() + commands_responses_queue = mp.Queue() + commands_metadata_queue = mp.Queue() kwargs = dict( api_publish_state_queue=api_publish_state_queue, api_delta_queue=api_delta_queue, has_started_queue=has_started_queue, + commands_requests_queue=commands_requests_queue, + commands_responses_queue=commands_responses_queue, + commands_metadata_queue=commands_metadata_queue, port=1111, ) @@ -354,12 +372,18 @@ def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatc api_publish_state_queue = MockQueue() api_delta_queue = MockQueue() has_started_queue = MockQueue() + commands_requests_queue = MockQueue() + commands_responses_queue = MockQueue() + commands_metadata_queue = MockQueue() kwargs = dict( host=host, port=1111, api_publish_state_queue=api_publish_state_queue, api_delta_queue=api_delta_queue, has_started_queue=has_started_queue, + commands_requests_queue=commands_requests_queue, + commands_responses_queue=commands_responses_queue, + commands_metadata_queue=commands_metadata_queue, ) monkeypatch.setattr(api, "logger", logging.getLogger()) diff --git a/tests/tests_app/source_code/test_uploader.py b/tests/tests_app/source_code/test_uploader.py index 82789e83e3..774442291d 100644 --- a/tests/tests_app/source_code/test_uploader.py +++ b/tests/tests_app/source_code/test_uploader.py @@ -39,10 +39,11 @@ def test_file_uploader(): @mock.patch("lightning_app.source_code.uploader.requests.Session", MockedRequestSession) def test_file_uploader_failing_when_no_etag(): response["response"] = MagicMock(headers={}) + presigned_url = "https://test-url" file_uploader = uploader.FileUploader( - presigned_url="https://test-url", source_file="test.txt", total_size=100, name="test.txt" + presigned_url=presigned_url, source_file="test.txt", total_size=100, name="test.txt" ) file_uploader.progress = MagicMock() - with pytest.raises(ValueError, match="Unexpected response from S3, response"): + with pytest.raises(ValueError, match=f"Unexpected response from {presigned_url}, response"): file_uploader.upload() diff --git a/tests/tests_app/utilities/test_commands.py b/tests/tests_app/utilities/test_commands.py new file mode 100644 index 0000000000..1e8e36ed09 --- /dev/null +++ b/tests/tests_app/utilities/test_commands.py @@ -0,0 +1,162 @@ +import argparse +import sys +from multiprocessing import Process +from time import sleep +from unittest.mock import MagicMock + +import pytest +import requests +from pydantic import BaseModel + +from lightning import LightningFlow +from lightning_app import LightningApp +from lightning_app.cli.lightning_cli import app_command +from lightning_app.core.constants import APP_SERVER_PORT +from lightning_app.runners import MultiProcessRuntime +from lightning_app.testing.helpers import RunIf +from lightning_app.utilities.commands.base import _command_to_method_and_metadata, _download_command, ClientCommand +from lightning_app.utilities.state import AppState + + +class SweepConfig(BaseModel): + sweep_name: str + num_trials: int + + +class SweepCommand(ClientCommand): + def run(self) -> None: + print(sys.argv) + parser = argparse.ArgumentParser() + parser.add_argument("--sweep_name", type=str) + parser.add_argument("--num_trials", type=int) + hparams = parser.parse_args() + + config = SweepConfig(sweep_name=hparams.sweep_name, num_trials=hparams.num_trials) + response = self.invoke_handler(config=config) + assert response is True + + +class FlowCommands(LightningFlow): + def __init__(self): + super().__init__() + self.names = [] + self.has_sweep = False + + def run(self): + if self.has_sweep and len(self.names) == 1: + sleep(2) + self._exit() + + def trigger_method(self, name: str): + self.names.append(name) + + def sweep(self, config: SweepConfig): + self.has_sweep = True + return True + + def configure_commands(self): + return [{"user_command": self.trigger_method}, {"sweep": SweepCommand(self.sweep)}] + + +class DummyConfig(BaseModel): + something: str + something_else: int + + +class DummyCommand(ClientCommand): + def run(self, something: str, something_else: int) -> None: + config = DummyConfig(something=something, something_else=something_else) + response = self.invoke_handler(config=config) + assert response == {"body": 0} + + +def run(config: DummyConfig): + assert isinstance(config, DummyCommand) + + +def run_failure_0(name: str): + pass + + +def run_failure_1(name): + pass + + +class CustomModel(BaseModel): + pass + + +def run_failure_2(name: CustomModel): + pass + + +@RunIf(skip_windows=True) +def test_command_to_method_and_metadata(): + with pytest.raises(Exception, match="The provided annotation for the argument name"): + _command_to_method_and_metadata(ClientCommand(run_failure_0)) + + with pytest.raises(Exception, match="annotate your method"): + _command_to_method_and_metadata(ClientCommand(run_failure_1)) + + with pytest.raises(Exception, match="lightning_app/utilities/commands/base.py"): + _command_to_method_and_metadata(ClientCommand(run_failure_2)) + + +def test_client_commands(monkeypatch): + import requests + + resp = MagicMock() + resp.status_code = 200 + value = {"body": 0} + resp.json = MagicMock(return_value=value) + post = MagicMock() + post.return_value = resp + monkeypatch.setattr(requests, "post", post) + url = "http//" + kwargs = {"something": "1", "something_else": "1"} + command = DummyCommand(run) + _, command_metadata = _command_to_method_and_metadata(command) + command_metadata.update( + { + "command": "dummy", + "affiliation": "root", + "is_client_command": True, + "owner": "root", + } + ) + client_command, models = _download_command(command_metadata, None) + client_command._setup(metadata=command_metadata, models=models, app_url=url) + client_command.run(**kwargs) + + +def target(): + app = LightningApp(FlowCommands()) + MultiProcessRuntime(app).dispatch() + + +def test_configure_commands(monkeypatch): + process = Process(target=target) + process.start() + time_left = 15 + while time_left > 0: + try: + requests.get(f"http://localhost:{APP_SERVER_PORT}/healthz") + break + except requests.exceptions.ConnectionError: + sleep(0.1) + time_left -= 0.1 + + sleep(0.5) + monkeypatch.setattr(sys, "argv", ["lightning", "user_command", "--name=something"]) + app_command() + sleep(0.5) + state = AppState() + state._request_state() + assert state.names == ["something"] + monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name", "my_name", "--num_trials", "1"]) + app_command() + time_left = 15 + while time_left > 0 or process.exitcode is None: + sleep(0.1) + time_left -= 0.1 + assert process.exitcode == 0 diff --git a/tests/tests_app/utilities/test_state.py b/tests/tests_app/utilities/test_state.py index e275817f68..0740ffc615 100644 --- a/tests/tests_app/utilities/test_state.py +++ b/tests/tests_app/utilities/test_state.py @@ -15,7 +15,7 @@ from lightning_app.utilities.state import AppState def test_app_state_not_connected(_): """Test an error message when a disconnected AppState tries to access attributes.""" - state = AppState() + state = AppState(port=8000) with pytest.raises(AttributeError, match="Failed to connect and fetch the app state"): _ = state.value with pytest.raises(AttributeError, match="Failed to connect and fetch the app state"): @@ -209,7 +209,7 @@ def test_attach_plugin(): @mock.patch("lightning_app.utilities.state._configure_session", return_value=requests) def test_app_state_connection_error(_): """Test an error message when a connection to retrieve the state can't be established.""" - app_state = AppState() + app_state = AppState(port=8000) with pytest.raises(AttributeError, match=r"Failed to connect and fetch the app state\. Is the app running?"): app_state._request_state() diff --git a/tests/tests_app_examples/test_commands.py b/tests/tests_app_examples/test_commands.py new file mode 100644 index 0000000000..5116b1b9d5 --- /dev/null +++ b/tests/tests_app_examples/test_commands.py @@ -0,0 +1,31 @@ +import os +from subprocess import Popen +from time import sleep +from unittest import mock + +import pytest +from tests_app import _PROJECT_ROOT + +from lightning_app.testing.testing import run_app_in_cloud + + +@mock.patch.dict(os.environ, {"SKIP_LIGHTING_UTILITY_WHEELS_BUILD": "0"}) +@pytest.mark.cloud +def test_commands_example_cloud() -> None: + with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_commands")) as ( + admin_page, + _, + fetch_logs, + ): + app_id = admin_page.url.split("/")[-1] + cmd = f"lightning trigger_with_client_command --name=something --app_id {app_id}" + Popen(cmd, shell=True).wait() + cmd = f"lightning trigger_without_client_command --name=else --app_id {app_id}" + Popen(cmd, shell=True).wait() + + has_logs = False + while not has_logs: + for log in fetch_logs(): + if "['something', 'else']" in log: + has_logs = True + sleep(1)