[App] Introduce Commands (#13602)

This commit is contained in:
thomas chaton 2022-07-25 19:13:46 +02:00 committed by GitHub
parent a8d7b4476c
commit 4c35867b61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 858 additions and 63 deletions

View File

@ -54,6 +54,7 @@ jobs:
- custom_work_dependencies
- drive
- payload
- commands
timeout-minutes: 35
steps:
- uses: actions/checkout@v2
@ -155,7 +156,7 @@ jobs:
shell: bash
run: |
mkdir -p ${VIDEO_LOCATION}
HEADLESS=1 python -m pytest tests/tests_app_examples/test_${{ matrix.app_name }}.py::test_${{ matrix.app_name }}_example_cloud --timeout=900 --capture=no -v --color=yes
HEADLESS=1 PACKAGE_LIGHTNING=1 python -m pytest tests/tests_app_examples/test_${{ matrix.app_name }}.py::test_${{ matrix.app_name }}_example_cloud --timeout=900 --capture=no -v --color=yes
# Delete the artifacts if successful
rm -r ${VIDEO_LOCATION}/${{ matrix.app_name }}

3
.gitignore vendored
View File

@ -109,6 +109,7 @@ celerybeat-schedule
# dotenv
.env
.env_stagging
# virtualenv
.venv
@ -160,3 +161,5 @@ tags
.tags
src/lightning_app/ui/*
*examples/template_react_ui*
hars*
artifacts/*

View File

@ -0,0 +1 @@
name: app-commands

View File

@ -0,0 +1,39 @@
from command import CustomCommand, CustomConfig
from lightning import LightningFlow
from lightning_app.core.app import LightningApp
class ChildFlow(LightningFlow):
    """Nested flow that exposes one command printing a greeting."""

    def configure_commands(self):
        # One mapping: public command name -> bound method handling it.
        return [{"nested_trigger_command": self.trigger_method}]

    def trigger_method(self, name: str):
        """Print a greeting for ``name``."""
        greeting = f"Hello {name}"
        print(greeting)
class FlowCommands(LightningFlow):
    """Root flow of the example: collects names received through its commands and prints them."""

    def __init__(self):
        super().__init__()
        self.names = []
        self.child_flow = ChildFlow()

    def configure_commands(self):
        # Commands handled by this flow: one plain method and one wrapped in a client command.
        own_commands = [
            {"trigger_without_client_command": self.trigger_without_client_command},
            {"trigger_with_client_command": CustomCommand(self.trigger_with_client_command)},
        ]
        # The child flow's commands are appended after this flow's own.
        nested_commands = self.child_flow.configure_commands()
        return own_commands + nested_commands

    def trigger_with_client_command(self, config: CustomConfig):
        """Record the name carried by ``config``."""
        self.names.append(config.name)

    def trigger_without_client_command(self, name: str):
        """Record ``name`` directly."""
        self.names.append(name)

    def run(self):
        # Only print once at least one name has been collected.
        if len(self.names) > 0:
            print(self.names)
app = LightningApp(FlowCommands())

View File

@ -0,0 +1,17 @@
from argparse import ArgumentParser
from pydantic import BaseModel
from lightning.app.utilities.commands import ClientCommand
# Pydantic payload built by `CustomCommand.run` from the `--name` CLI argument
# and passed as `config=` to `invoke_handler`.
class CustomConfig(BaseModel):
    name: str  # value of the `--name` argument
class CustomCommand(ClientCommand):
    """Client command that reads ``--name`` from the command line and forwards it as a ``CustomConfig``."""

    def run(self):
        cli = ArgumentParser()
        cli.add_argument("--name", type=str)
        parsed = cli.parse_args()
        # The keyword must be `config`: it is the name under which the payload is sent to the handler.
        payload = CustomConfig(name=parsed.name)
        self.invoke_handler(config=payload)

View File

@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602))
### Changed
- Update the Lightning App docs ([#13537](https://github.com/PyTorchLightning/pytorch-lightning/pull/13537))
### Changed

View File

@ -1,9 +1,13 @@
import logging
import os
import sys
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Tuple, Union
from uuid import uuid4
import click
import requests
from requests.exceptions import ConnectionError
from lightning_app import __version__ as ver
@ -11,9 +15,13 @@ from lightning_app.cli import cmd_init, cmd_install, cmd_pl_init, cmd_react_ui_i
from lightning_app.core.constants import get_lightning_cloud_url, LOCAL_LAUNCH_ADMIN_VIEW
from lightning_app.runners.runtime import dispatch
from lightning_app.runners.runtime_type import RuntimeType
from lightning_app.utilities.cli_helpers import _format_input_env_variables
from lightning_app.utilities.cli_helpers import (
_format_input_env_variables,
_retrieve_application_url_and_available_commands,
)
from lightning_app.utilities.install_components import register_all_external_components
from lightning_app.utilities.login import Auth
from lightning_app.utilities.state import headers_for
logger = logging.getLogger(__name__)
@ -26,14 +34,23 @@ def get_app_url(runtime_type: RuntimeType, *args) -> str:
return "http://127.0.0.1:7501/admin" if LOCAL_LAUNCH_ADMIN_VIEW else "http://127.0.0.1:7501/view"
def main():
    """Entry point of the ``lightning`` CLI.

    Dispatches to the click group ``_main`` when no argument is given, when the first
    argument is a command registered on ``_main``, or when it is one of the group's own
    top-level options (``--help`` / ``--version``). Anything else is treated as the
    invocation of a command exposed by a running app and handed to ``app_command``.
    """
    # `--version` is registered on `_main` through `click.version_option`, so it has to be
    # routed to click as well; otherwise it would be interpreted as an app command.
    click_options = ("--help", "--version")
    if len(sys.argv) == 1 or sys.argv[1] in _main.commands or sys.argv[1] in click_options:
        _main()
    else:
        app_command()
@click.group()
@click.version_option(ver)
def main():
def _main():
register_all_external_components()
pass
@main.command()
@_main.command()
def login():
"""Log in to your Lightning.ai account."""
auth = Auth()
@ -46,7 +63,7 @@ def login():
exit(1)
@main.command()
@_main.command()
def logout():
"""Log out of your Lightning.ai account."""
Auth().clear()
@ -93,7 +110,7 @@ def _run_app(
click.echo("Application is ready in the cloud")
@main.group()
@_main.group()
def run():
"""Run your application."""
@ -125,31 +142,83 @@ def run_app(
_run_app(file, cloud, without_server, no_cache, name, blocking, open_ui, env)
@main.group(hidden=True)
def app_command():
    """Execute a function in a running application from its name.

    The command line is expected to look like
    ``lightning [--app_id=X] <command_name> [--param=value ...]``.
    The available commands are fetched from the running app; a regular command is
    posted to the app's ``/api/v1/commands`` endpoint, whereas a client command is
    downloaded and its ``run`` method executed locally.

    Raises:
        Exception: if no running app or no command is found, if the command name is
            missing or unknown, or if an argument is malformed or missing.
    """
    from lightning_app.utilities.commands.base import _download_command

    # `Logger.warn` is deprecated in favour of `Logger.warning`.
    logger.warning("Lightning Commands are a beta feature and APIs aren't stable yet.")

    debug_mode = bool(int(os.getenv("DEBUG", "0")))

    parser = ArgumentParser()
    parser.add_argument("--app_id", default=None, type=str, help="Optional argument to identify an application.")
    # Everything that isn't `--app_id` (the command name and its arguments) ends up in `argv`.
    hparams, argv = parser.parse_known_args()

    # 1: Collect the url and commands from the running application
    url, commands = _retrieve_application_url_and_available_commands(hparams.app_id)
    if url is None or commands is None:
        raise Exception("We couldn't find any matching running app.")

    if not commands:
        raise Exception("This application doesn't expose any commands yet.")

    command_names = [c["command"] for c in commands]
    if not argv:
        # e.g. `lightning --app_id=X` without any command name.
        raise Exception(f"No command name was provided. Available commands: {command_names}")

    command = argv[0]
    if command not in command_names:
        raise Exception(f"The provided command {command} isn't available in {command_names}")

    # 2: Send the command from the user
    command_metadata = [c for c in commands if c["command"] == command][0]
    params = command_metadata["params"]

    # 3: Execute the command
    if not command_metadata["is_client_command"]:
        # TODO: Improve what is supported there.
        kwargs = {}
        for raw_argument in argv[1:]:
            # Split on the first `=` only, so that values may themselves contain `=`.
            key, separator, value = raw_argument.partition("=")
            if not separator:
                raise Exception(f"The argument {raw_argument} should be provided as --name=value.")
            kwargs[key.replace("--", "")] = value

        for param in params:
            if param not in kwargs:
                raise Exception(f"The argument --{param}=X hasn't been provided.")

        json = {
            "command_name": command,
            "command_arguments": kwargs,
            "affiliation": command_metadata["affiliation"],
            "id": str(uuid4()),
        }
        resp = requests.post(url + "/api/v1/commands", json=json, headers=headers_for({}))
        assert resp.status_code == 200, resp.json()
    else:
        client_command, models = _download_command(command_metadata, hparams.app_id, debug_mode=debug_mode)
        client_command._setup(metadata=command_metadata, models=models, app_url=url)
        # The client command parses its own arguments from `sys.argv`; `argv[0]` (the command
        # name) plays the role of the program name.
        sys.argv = argv
        client_command.run()
@_main.group(hidden=True)
def fork():
"""Fork an application."""
pass
@main.group(hidden=True)
@_main.group(hidden=True)
def stop():
"""Stop your application."""
pass
@main.group(hidden=True)
@_main.group(hidden=True)
def delete():
"""Delete an application."""
pass
@main.group(name="list", hidden=True)
@_main.group(name="list", hidden=True)
def get_list():
"""List your applications."""
pass
@main.group()
@_main.group()
def install():
"""Install Lightning apps and components."""
@ -207,7 +276,7 @@ def install_component(name, yes, version):
cmd_install.gallery_component(name, yes, version)
@main.group()
@_main.group()
def init():
"""Init a Lightning app and component."""

View File

@ -93,8 +93,6 @@ class TracerPythonScript(LightningWork):
:language: python
"""
super().__init__(**kwargs)
if not os.path.exists(script_path):
raise FileNotFoundError(f"The provided `script_path` {script_path}` wasn't found.")
self.script_path = str(script_path)
if isinstance(script_args, str):
script_args = script_args.split(" ")
@ -105,6 +103,8 @@ class TracerPythonScript(LightningWork):
setattr(self, name, None)
def run(self, **kwargs):
if not os.path.exists(self.script_path):
raise FileNotFoundError(f"The provided `script_path` {self.script_path}` wasn't found.")
kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()}
init_globals = globals()
init_globals.update(kwargs)

View File

@ -3,6 +3,8 @@ import logging
import os
import queue
import sys
import time
import traceback
from copy import deepcopy
from multiprocessing import Queue
from threading import Event, Lock, Thread
@ -40,6 +42,10 @@ STATE_EVENT = "State changed"
frontend_static_dir = os.path.join(FRONTEND_DIR, "static")
api_app_delta_queue: Queue = None
api_commands_requests_queue: Queue = None
api_commands_metadata_queue: Queue = None
api_commands_responses_queue: Queue = None
template = {"ui": {}, "app": {}}
templates = Jinja2Templates(directory=FRONTEND_DIR)
@ -50,6 +56,8 @@ global_app_state_store.add(TEST_SESSION_UUID)
lock = Lock()
app_spec: Optional[List] = None
app_commands_metadata: Optional[Dict] = None
commands_response_store = {}
logger = logging.getLogger(__name__)
@ -59,16 +67,22 @@ logger = logging.getLogger(__name__)
class UIRefresher(Thread):
def __init__(self, api_publish_state_queue) -> None:
def __init__(self, api_publish_state_queue, api_commands_metadata_queue, api_commands_responses_queue) -> None:
super().__init__(daemon=True)
self.api_publish_state_queue = api_publish_state_queue
self.api_commands_metadata_queue = api_commands_metadata_queue
self.api_commands_responses_queue = api_commands_responses_queue
self._exit_event = Event()
def run(self):
# TODO: Create multiple threads to handle the background logic
# TODO: Investigate the use of `parallel=True`
while not self._exit_event.is_set():
self.run_once()
try:
while not self._exit_event.is_set():
self.run_once()
except Exception as e:
logger.error(traceback.print_exc())
raise e
def run_once(self):
try:
@ -78,6 +92,22 @@ class UIRefresher(Thread):
except queue.Empty:
pass
try:
metadata = self.api_commands_metadata_queue.get(timeout=0)
with lock:
global app_commands_metadata
app_commands_metadata = metadata
except queue.Empty:
pass
try:
response = self.api_commands_responses_queue.get(timeout=0)
with lock:
global commands_response_store
commands_response_store[response["id"]] = response["response"]
except queue.Empty:
pass
def join(self, timeout: Optional[float] = None) -> None:
self._exit_event.set()
super().join(timeout)
@ -146,6 +176,43 @@ async def get_spec(
return app_spec or []
@fastapi_service.post("/api/v1/commands", response_class=JSONResponse)
async def run_remote_command(
    request: Request,
) -> None:
    """Forward a command request to the app and wait for its response.

    The JSON body must contain ``command_name``, ``command_arguments``, ``affiliation``
    and ``id``. The request is put on ``api_commands_requests_queue`` and the handler
    polls ``commands_response_store`` until an entry for ``id`` appears, then returns
    (and removes) that entry.

    Raises:
        Exception: if a required field is missing, or if no response shows up within
            15 seconds.
    """
    data = await request.json()
    command_name = data.get("command_name", None)
    if not command_name:
        raise Exception("The provided command name is empty.")
    command_arguments = data.get("command_arguments", None)
    # Only a missing field is an error: an empty mapping is a valid payload for a
    # command whose method takes no argument.
    if command_arguments is None:
        raise Exception("The provided command metadata is empty.")
    affiliation = data.get("affiliation", None)
    if not affiliation:
        raise Exception("The provided affiliation is empty.")

    request_id = data["id"]
    api_commands_requests_queue.put(data)

    started_at = time.time()
    while request_id not in commands_response_store:
        await asyncio.sleep(0.1)
        if (time.time() - started_at) > 15:
            raise Exception("The response was never received.")

    # Pop rather than read, so that served responses don't accumulate in the store
    # for the whole lifetime of the server.
    with lock:
        return commands_response_store.pop(request_id)
@fastapi_service.get("/api/v1/commands", response_class=JSONResponse)
async def get_commands() -> Optional[Dict]:
    """Return the module-level ``app_commands_metadata``, read while holding ``lock``."""
    with lock:
        metadata = app_commands_metadata
    return metadata
@fastapi_service.post("/api/v1/delta")
async def post_delta(
request: Request,
@ -279,6 +346,9 @@ class LightningUvicornServer(uvicorn.Server):
def start_server(
api_publish_state_queue,
api_delta_queue,
commands_requests_queue,
commands_responses_queue,
commands_metadata_queue,
has_started_queue: Optional[Queue] = None,
host="127.0.0.1",
port=8000,
@ -288,16 +358,22 @@ def start_server(
):
global api_app_delta_queue
global global_app_state_store
global api_commands_requests_queue
global api_commands_responses_queue
global app_spec
app_spec = spec
api_app_delta_queue = api_delta_queue
api_commands_requests_queue = commands_requests_queue
api_commands_responses_queue = commands_responses_queue
api_commands_metadata_queue = commands_metadata_queue
if app_state_store is not None:
global_app_state_store = app_state_store
global_app_state_store.add(TEST_SESSION_UUID)
refresher = UIRefresher(api_publish_state_queue)
refresher = UIRefresher(api_publish_state_queue, api_commands_metadata_queue, commands_responses_queue)
refresher.setDaemon(True)
refresher.start()

View File

@ -16,6 +16,7 @@ from lightning_app.core.queues import BaseQueue, SingleProcessQueue
from lightning_app.frontend import Frontend
from lightning_app.storage.path import storage_root_dir
from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef
from lightning_app.utilities.commands.base import _populate_commands_endpoint, _process_command_requests
from lightning_app.utilities.component import _convert_paths_after_init
from lightning_app.utilities.enum import AppStage
from lightning_app.utilities.exceptions import CacheMissException, ExitAppException
@ -72,6 +73,9 @@ class LightningApp:
# queues definition.
self.delta_queue: t.Optional[BaseQueue] = None
self.readiness_queue: t.Optional[BaseQueue] = None
self.commands_requests_queue: t.Optional[BaseQueue] = None
self.commands_responses_queue: t.Optional[BaseQueue] = None
self.commands_metadata_queue: t.Optional[BaseQueue] = None
self.api_publish_state_queue: t.Optional[BaseQueue] = None
self.api_delta_queue: t.Optional[BaseQueue] = None
self.error_queue: t.Optional[BaseQueue] = None
@ -81,6 +85,7 @@ class LightningApp:
self.copy_response_queues: t.Optional[t.Dict[str, BaseQueue]] = None
self.caller_queues: t.Optional[t.Dict[str, BaseQueue]] = None
self.work_queues: t.Optional[t.Dict[str, BaseQueue]] = None
self.commands: t.Optional[t.List] = None
self.should_publish_changes_to_api = False
self.component_affiliation = None
@ -345,6 +350,8 @@ class LightningApp:
elif self.stage == AppStage.RESTARTING:
return self._apply_restarting()
_process_command_requests(self)
try:
self.check_error_queue()
t0 = time()
@ -397,6 +404,8 @@ class LightningApp:
self._reset_run_time_monitor()
_populate_commands_endpoint(self)
while not done:
done = self.run_once()

View File

@ -22,7 +22,7 @@ REDIS_QUEUES_READ_DEFAULT_TIMEOUT = 0.005
REDIS_WARNING_QUEUE_SIZE = 1000
USER_ID = os.getenv("USER_ID", "1234")
FRONTEND_DIR = os.path.join(os.path.dirname(lightning_app.__file__), "ui")
PREPARE_LIGHTING = bool(int(os.getenv("PREPARE_LIGHTING", "0")))
PACKAGE_LIGHTNING = os.getenv("PACKAGE_LIGHTNING", None)
LOCAL_LAUNCH_ADMIN_VIEW = bool(int(os.getenv("LOCAL_LAUNCH_ADMIN_VIEW", "0")))
CLOUD_UPLOAD_WARNING = int(os.getenv("CLOUD_UPLOAD_WARNING", "2"))
DISABLE_DEPENDENCY_CACHE = bool(int(os.getenv("DISABLE_DEPENDENCY_CACHE", "0")))

View File

@ -356,8 +356,7 @@ class LightningFlow:
class Flow(LightningFlow):
def run(self):
if self.schedule("hourly"):
# run some code once every hour.
print("run this every hour")
print("run some code every hour")
Arguments:
cron_pattern: The cron pattern to provide. Learn more at https://crontab.guru/.
@ -509,20 +508,16 @@ class LightningFlow:
# add your streamlit code here!
import streamlit as st
st.button("Hello!")
**Example:** Arrange the UI of my children in tabs (default UI by Lightning).
.. code-block:: python
class Flow(LightningFlow):
...
def configure_layout(self):
return [
dict(name="First Tab", content=self.child0),
dict(name="Second Tab", content=self.child1),
# You can include direct URLs too
dict(name="Lightning", content="https://lightning.ai"),
]
@ -608,3 +603,33 @@ class LightningFlow:
yield value
self._calls[call_hash].update({"has_finished": True})
def configure_commands(self):
    """Configure the commands of this LightningFlow.

    Returns a list of dictionaries mapping a command name to a flow method.

    .. code-block:: python

        class Flow(LightningFlow):
            def __init__(self):
                super().__init__()
                self.names = []

            def configure_commands(self):
                return [{"my_command_name": self.my_remote_method}]

            def my_remote_method(self, name):
                self.names.append(name)

    Once the app is running with the following command:

    .. code-block:: bash

        lightning run app app.py

    the command can be triggered with its arguments passed as ``--<argument>=<value>``:

    .. code-block:: bash

        lightning my_command_name --name=my_own_name
    """
    raise NotImplementedError

View File

@ -36,6 +36,9 @@ ORCHESTRATOR_RESPONSE_CONSTANT = "ORCHESTRATOR_RESPONSE"
ORCHESTRATOR_COPY_REQUEST_CONSTANT = "ORCHESTRATOR_COPY_REQUEST"
ORCHESTRATOR_COPY_RESPONSE_CONSTANT = "ORCHESTRATOR_COPY_RESPONSE"
WORK_QUEUE_CONSTANT = "WORK_QUEUE"
COMMANDS_REQUESTS_QUEUE_CONSTANT = "COMMANDS_REQUESTS_QUEUE"
COMMANDS_RESPONSES_QUEUE_CONSTANT = "COMMANDS_RESPONSES_QUEUE"
COMMANDS_METADATA_QUEUE_CONSTANT = "COMMANDS_METADATA_QUEUE"
class QueuingSystem(Enum):
@ -51,6 +54,20 @@ class QueuingSystem(Enum):
else:
return SingleProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
def get_commands_requests_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
    """Return the queue named after the commands-requests constant, prefixed with ``queue_id`` when given."""
    if queue_id:
        queue_name = f"{queue_id}_{COMMANDS_REQUESTS_QUEUE_CONSTANT}"
    else:
        queue_name = COMMANDS_REQUESTS_QUEUE_CONSTANT
    return self._get_queue(queue_name)
def get_commands_responses_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
    """Return the queue named after the commands-responses constant, prefixed with ``queue_id`` when given."""
    if queue_id:
        queue_name = f"{queue_id}_{COMMANDS_RESPONSES_QUEUE_CONSTANT}"
    else:
        queue_name = COMMANDS_RESPONSES_QUEUE_CONSTANT
    return self._get_queue(queue_name)
def get_commands_metadata_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
    """Return the queue named after the commands-metadata constant, prefixed with ``queue_id`` when given."""
    if queue_id:
        queue_name = f"{queue_id}_{COMMANDS_METADATA_QUEUE_CONSTANT}"
    else:
        queue_name = COMMANDS_METADATA_QUEUE_CONSTANT
    return self._get_queue(queue_name)
def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{READINESS_QUEUE_CONSTANT}" if queue_id else READINESS_QUEUE_CONSTANT
return self._get_queue(queue_name)

View File

@ -82,9 +82,11 @@ class Backend(ABC):
kw = dict(queue_id=self.queue_id)
app.delta_queue = self.queues.get_delta_queue(**kw)
app.readiness_queue = self.queues.get_readiness_queue(**kw)
app.commands_requests_queue = self.queues.get_commands_requests_queue(**kw)
app.commands_responses_queue = self.queues.get_commands_responses_queue(**kw)
app.commands_metadata_queue = self.queues.get_commands_metadata_queue(**kw)
app.error_queue = self.queues.get_error_queue(**kw)
app.delta_queue = self.queues.get_delta_queue(**kw)
app.readiness_queue = self.queues.get_readiness_queue(**kw)
app.error_queue = self.queues.get_error_queue(**kw)
app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw)
app.api_delta_queue = self.queues.get_api_delta_queue(**kw)

View File

@ -66,6 +66,9 @@ class MultiProcessRuntime(Runtime):
api_publish_state_queue=self.app.api_publish_state_queue,
api_delta_queue=self.app.api_delta_queue,
has_started_queue=has_started_queue,
commands_requests_queue=self.app.commands_requests_queue,
commands_responses_queue=self.app.commands_responses_queue,
commands_metadata_queue=self.app.commands_metadata_queue,
spec=extract_metadata_from_app(self.app),
)
server_proc = multiprocessing.Process(target=start_server, kwargs=kwargs)

View File

@ -39,16 +39,15 @@ class FileUploader:
self.total_size = total_size
self.name = name
@staticmethod
def upload_s3_data(url: str, data: bytes, retries: int, disconnect_retry_wait_seconds: int) -> str:
"""Send data to s3 url.
def upload_data(self, url: str, data: bytes, retries: int, disconnect_retry_wait_seconds: int) -> str:
"""Send data to url.
Parameters
----------
url: str
S3 url string to send data to
url string to send data to
data: bytes
Bytes of data to send to S3
Bytes of data to send to url
retries: int
Amount of retries
disconnect_retry_wait_seconds: int
@ -65,16 +64,19 @@ class FileUploader:
retries = Retry(total=10)
with requests.Session() as s:
s.mount("https://", HTTPAdapter(max_retries=retries))
response = s.put(url, data=data)
if "ETag" not in response.headers:
raise ValueError(f"Unexpected response from S3, response: {response.content}")
return response.headers["ETag"]
return self._upload_data(s, url, data)
except BrokenPipeError:
time.sleep(disconnect_retry_wait_seconds)
disconnect_retries -= 1
raise ValueError("Unable to upload file after multiple attempts")
def _upload_data(self, s: requests.Session, url: str, data: bytes):
    """PUT ``data`` to ``url`` through the session ``s`` and return the response's ``ETag`` header.

    Raises:
        ValueError: if the response carries no ``ETag`` header.
    """
    response = s.put(url, data=data)
    if "ETag" in response.headers:
        return response.headers["ETag"]
    raise ValueError(f"Unexpected response from {url}, response: {response.content}")
def upload(self) -> None:
"""Upload files from source dir into target path in S3."""
task_id = self.progress.add_task("upload", filename=self.name, total=self.total_size)
@ -82,7 +84,7 @@ class FileUploader:
try:
with open(self.source_file, "rb") as f:
data = f.read()
self.upload_s3_data(self.presigned_url, data, self.retries, self.disconnect_retry_wait_seconds)
self.upload_data(self.presigned_url, data, self.retries, self.disconnect_retry_wait_seconds)
self.progress.update(task_id, advance=len(data))
finally:
self.progress.stop()

View File

@ -179,11 +179,7 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator:
# 5. Create chromium browser, auth to lightning_app.ai and yield the admin and view pages.
with sync_playwright() as p:
browser = p.chromium.launch(headless=bool(int(os.getenv("HEADLESS", "0"))))
payload = {
"apiKey": Config.api_key,
"username": Config.username,
"duration": "120000",
}
payload = {"apiKey": Config.api_key, "username": Config.username, "duration": "120000"}
context = browser.new_context(
# Eventually this will need to be deleted
http_credentials=HttpCredentials(

View File

@ -1,5 +1,11 @@
import re
from typing import Dict
from typing import Dict, Optional
import requests
from lightning_app.core.constants import APP_SERVER_PORT
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.network import LightningClient
def _format_input_env_variables(env_list: tuple) -> Dict[str, str]:
@ -35,3 +41,55 @@ def _format_input_env_variables(env_list: tuple) -> Dict[str, str]:
env_vars_dict[var_name] = value
return env_vars_dict
def _is_url(id: Optional[str]) -> bool:
if isinstance(id, str) and (id.startswith("https://") or id.startswith("http://")):
return True
return False
def _retrieve_application_url_and_available_commands(app_id_or_name_or_url: Optional[str]):
    """Resolve an app identifier to its url and fetch the commands it exposes.

    Returns a ``(url, commands)`` tuple where ``commands`` is the JSON payload served by
    ``<url>/api/v1/commands``, or ``(None, None)`` when no app could be matched.
    """

    def _commands_from(response):
        # Every lookup path validates the server answer the same way.
        if response.status_code != 200:
            raise Exception(f"The server didn't process the request properly. Found {response.json()}")
        return response.json()

    # 1: The identifier is already a url, query it directly.
    if _is_url(app_id_or_name_or_url):
        url = app_id_or_name_or_url
        assert url
        return url, _commands_from(requests.get(url + "/api/v1/commands"))

    # 2: If no identifier has been provided, evaluate the local application
    local_unreachable = False
    if app_id_or_name_or_url is None:
        try:
            url = f"http://localhost:{APP_SERVER_PORT}"
            return url, _commands_from(requests.get(f"{url}/api/v1/commands"))
        except requests.exceptions.ConnectionError:
            local_unreachable = True

    # 3: If an identifier was provided or the local evaluation has failed, evaluate the cloud.
    if not (app_id_or_name_or_url or local_unreachable):
        return None, None

    client = LightningClient()
    project = _get_project(client)
    instances = client.lightningapp_instance_service_list_lightningapp_instances(project.project_id)
    lightningapp_names = [instance.name for instance in instances.lightningapps]

    if not app_id_or_name_or_url:
        raise Exception(f"Provide an application name, id or url with --app_id=X. Found {lightningapp_names}")

    for instance in instances.lightningapps:
        matches = instance.id == app_id_or_name_or_url or instance.name == app_id_or_name_or_url
        if not matches:
            continue
        if instance.status.url == "":
            raise Exception("The application is starting. Try in a few moments.")
        response = requests.get(instance.status.url + "/api/v1/commands")
        return instance.status.url, _commands_from(response)

    return None, None

View File

@ -0,0 +1,3 @@
from lightning_app.utilities.commands.base import ClientCommand
__all__ = ["ClientCommand"]

View File

@ -0,0 +1,245 @@
import errno
import inspect
import logging
import os
import os.path as osp
import shutil
import sys
from getpass import getuser
from importlib.util import module_from_spec, spec_from_file_location
from tempfile import gettempdir
from typing import Any, Callable, Dict, List, Optional, Tuple
from uuid import uuid4
import requests
from pydantic import BaseModel
from lightning_app.utilities.app_helpers import is_overridden
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.network import LightningClient
from lightning_app.utilities.state import AppState
_logger = logging.getLogger(__name__)
def makedirs(path: str):
    r"""Recursive directory creation function.

    The path is normalized and ``~`` is expanded before creation. An already existing
    directory is not an error; any other failure (for instance a regular file already
    sitting at ``path``, or missing permissions) is propagated as an ``OSError``.
    """
    # `exist_ok=True` tolerates exactly the "directory already exists" case. The previous
    # hand-written errno check also swallowed genuine failures.
    os.makedirs(osp.expanduser(osp.normpath(path)), exist_ok=True)
# Schema used by `_download_command` to validate the metadata of a client command.
class _ClientCommandConfig(BaseModel):
    command: str  # command name; also names the `commands/<command>.py` artifact
    affiliation: str  # `name` of the flow owning the handler method
    params: Dict[str, str]  # handler parameter name -> name of its annotation class
    is_client_command: bool
    cls_path: str  # file defining the command class
    cls_name: str  # name of the command class inside that file
    owner: str  # taken from `ClientCommand._to_dict`
    requirements: Optional[List[str]]  # taken from `ClientCommand._to_dict`
class ClientCommand:
    """Base class for commands with logic running on the client side.

    Subclasses override :meth:`run` and call :meth:`invoke_handler` to post their
    arguments to the app. On the app side, calling the instance parses the received
    JSON arguments back into models and invokes the wrapped ``method``.

    Arguments:
        method: The method executed by the app when the command is received.
        requirements: Optional list of requirements attached to the command metadata.
    """

    def __init__(self, method: Callable, requirements: Optional[List[str]] = None) -> None:
        self.method = method
        # For a bound method, `__self__` is the owning flow.
        flow = getattr(method, "__self__", None)
        self.owner = flow.name if flow else None
        self.requirements = requirements
        self.metadata = None
        self.models: Optional[Dict[str, BaseModel]] = None
        self.app_url = None
        self._state = None

    def _setup(self, metadata: Dict[str, Any], models: Dict[str, BaseModel], app_url: str) -> None:
        """Attach the command metadata, its argument models and the url of the app."""
        self.metadata = metadata
        self.models = models
        self.app_url = app_url

    @property
    def state(self):
        """The state of the app at ``app_url``, requested on first access and cached."""
        if self._state is None:
            assert self.app_url
            # TODO: Resolve this hack
            os.environ["LIGHTNING_APP_STATE_URL"] = "1"
            try:
                state = AppState(host=self.app_url)
                state._request_state()
            finally:
                # Always undo the temporary environment change, even when the request fails.
                os.environ.pop("LIGHTNING_APP_STATE_URL", None)
            # Cache only a state that has been successfully requested.
            self._state = state
        return self._state

    def run(self, **cli_kwargs) -> None:
        """Overrides with the logic to execute on the client side."""

    def invoke_handler(self, **kwargs: Any) -> Dict[str, Any]:
        """Post the given models to the app's commands endpoint and return the JSON response.

        The keyword names must match the names in ``self.models`` and each value must be
        an instance of the corresponding model.
        """
        # Same package as the module-level `AppState` import.
        from lightning_app.utilities.state import headers_for

        assert kwargs.keys() == self.models.keys()
        for k, v in kwargs.items():
            assert isinstance(v, self.models[k])
        json = {
            "command_name": self.metadata["command"],
            "command_arguments": {k: v.json() for k, v in kwargs.items()},
            "affiliation": self.metadata["affiliation"],
            "id": str(uuid4()),
        }
        resp = requests.post(self.app_url + "/api/v1/commands", json=json, headers=headers_for({}))
        assert resp.status_code == 200, resp.json()
        return resp.json()

    def _to_dict(self):
        """Return the command-specific fields merged into the command metadata."""
        return {"owner": self.owner, "requirements": self.requirements}

    def __call__(self, **kwargs: Any) -> Any:
        """Parse each raw JSON argument with its model and call the wrapped method."""
        assert self.models
        parsed = {name: self.models[name].parse_raw(raw) for name, raw in kwargs.items()}
        return self.method(**parsed)
def _download_command(
    command_metadata: Dict[str, Any],
    app_id: Optional[str],
    debug_mode: bool = False,
) -> Tuple[ClientCommand, Dict[str, BaseModel]]:
    """Load the client command class described by ``command_metadata`` and its argument models.

    With an ``app_id``, the ``commands/<command>.py`` artifact of that app is downloaded
    into a per-user temporary folder; without one, the file at ``cls_path`` is copied there
    (unless ``debug_mode`` is set, in which case ``cls_path`` is loaded directly). The file
    is then executed as a module, the command class is instantiated with ``method=None``
    and each entry of ``params`` is resolved to an attribute of that module.

    NOTE(review): this executes code fetched from the artifact url; see the TODO below.
    """
    # TODO: This is a skateboard implementation and the final version will rely on versioned
    # immutable commands for security concerns
    config = _ClientCommandConfig(**command_metadata)
    tmpdir = osp.join(gettempdir(), f"{getuser()}_commands")
    makedirs(tmpdir)
    target_file = osp.join(tmpdir, f"{config.command}.py")
    if app_id:
        client = LightningClient()
        project_id = _get_project(client).project_id
        response = client.lightningapp_instance_service_list_lightningapp_instance_artifacts(project_id, app_id)
        # NOTE(review): if no artifact matches, `target_file` is not (re)written here — confirm
        # whether a missing or stale file should be reported explicitly.
        for artifact in response.artifacts:
            if f"commands/{config.command}.py" == artifact.filename:
                r = requests.get(artifact.url, allow_redirects=True)
                with open(target_file, "wb") as f:
                    f.write(r.content)
    else:
        if not debug_mode:
            shutil.copy(config.cls_path, target_file)

    cls_name = config.cls_name
    # In debug mode the original source file is loaded rather than the copy.
    spec = spec_from_file_location(config.cls_name, config.cls_path if debug_mode else target_file)
    mod = module_from_spec(spec)
    # Register the module under the class name before executing it.
    sys.modules[cls_name] = mod
    spec.loader.exec_module(mod)
    command = getattr(mod, cls_name)(method=None, requirements=config.requirements)
    models = {k: getattr(mod, v) for k, v in config.params.items()}
    if debug_mode:
        shutil.rmtree(tmpdir)
    return command, models
def _to_annotation(anno: str) -> str:
anno = anno.split("'")[1]
if "." in anno:
return anno.split(".")[-1]
return anno
def _command_to_method_and_metadata(command: ClientCommand) -> Tuple[Callable, Dict[str, Any]]:
    """Extract method and its metadata from a ClientCommand.

    Every parameter of the wrapped method must be annotated with a pydantic ``BaseModel``
    subclass defined in the same file as the command class. The resolved classes are
    stored on ``command.models``.

    Returns:
        The wrapped method and a metadata dictionary (``cls_path``, ``cls_name``,
        ``params`` plus the entries of ``command._to_dict()``).

    Raises:
        Exception: if a parameter isn't annotated, if its annotation lives in the method's
            file rather than the command's, or if it isn't a ``BaseModel`` subclass.
    """
    params = inspect.signature(command.method).parameters
    command_metadata = {
        "cls_path": inspect.getfile(command.__class__),
        "cls_name": command.__class__.__name__,
        "params": {p.name: _to_annotation(str(p.annotation)) for p in params.values()},
        **command._to_dict(),
    }
    method = command.method
    command.models = {}
    for k, v in command_metadata["params"].items():
        # `_empty` is what a missing annotation resolves to.
        if v == "_empty":
            raise Exception(
                f"Please, annotate your method {method} with pydantic BaseModel. Refer to the documentation."
            )
        config = getattr(sys.modules[command.__module__], v, None)
        if config is None:
            # Found next to the method only: the model was defined in the wrong file.
            config = getattr(sys.modules[method.__module__], v, None)
            if config:
                raise Exception(
                    f"The provided annotation for the argument {k} should be in the file "
                    f"{inspect.getfile(command.__class__)}, not {inspect.getfile(command.method)}."
                )
        # `issubclass` requires a class; guard against names resolving to other objects.
        if config is None or not (isinstance(config, type) and issubclass(config, BaseModel)):
            raise Exception(
                f"The provided annotation for the argument {k} should be a subclass of pydantic BaseModel."
            )
        command.models[k] = config
    return method, command_metadata
def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]:
    """Copy the source file of the command class to the shared storage.

    The file is stored as ``artifacts/commands/<command_name>.py`` under the shared
    storage path.

    Returns:
        The relative path ``commands/<command_name>.py``, or ``None`` when ``s3fs`` is
        available but the filesystem in use isn't an ``S3FileSystem`` (nothing is
        uploaded in that case).
    """
    from lightning_app.storage.path import _is_s3fs_available, filesystem, shared_storage_path

    filepath = f"commands/{command_name}.py"
    fs = filesystem()

    if _is_s3fs_available():
        from s3fs import S3FileSystem

        if not isinstance(fs, S3FileSystem):
            return None

    source_file = str(inspect.getfile(command.__class__))
    # Computed once, right before use.
    remote_url = str(shared_storage_path() / "artifacts" / filepath)
    fs.put(source_file, remote_url)
    return filepath
def _populate_commands_endpoint(app):
    """Collect the metadata of the root flow's commands and send it to the REST API.

    Does nothing when the root flow doesn't override ``configure_commands``. The
    configured commands are also stored on ``app.commands``.
    """
    if not is_overridden("configure_commands", app.root):
        return

    # 1: Populate commands metadata
    commands = app.root.configure_commands()
    seen_names = set()
    collected_metadata = []
    for mapping in commands:
        for name, handler in mapping.items():
            client_side = isinstance(handler, ClientCommand)
            extra_metadata = {}
            if client_side:
                # Client commands are uploaded, then unwrapped to their underlying method.
                _upload_command(name, handler)
                handler, extra_metadata = _command_to_method_and_metadata(handler)
            if name in seen_names:
                raise Exception(f"The component name {name} has already been used. They need to be unique.")
            seen_names.add(name)

            signature_params = inspect.signature(handler).parameters
            entry = {
                "command": name,
                "affiliation": handler.__self__.name,
                "params": list(signature_params.keys()),
                "is_client_command": client_side,
            }
            # Client command metadata takes precedence over the generic entries.
            entry.update(extra_metadata)
            collected_metadata.append(entry)

    # 1.2: Pass the collected commands through the queue to the Rest API.
    app.commands_metadata_queue.put(collected_metadata)
    app.commands = commands
def _process_command_requests(app):
    """Run the handler of a pending command request, if any, and queue its response.

    Does nothing when the root flow doesn't override ``configure_commands``.
    """
    if not is_overridden("configure_commands", app.root):
        return

    # 1: Populate commands metadata
    commands = app.commands

    # 2: Collect requests metadata
    command_query = app.get_state_changed_from_queue(app.commands_requests_queue)
    if not command_query:
        return

    for mapping in commands:
        for name, handler in mapping.items():
            if command_query["command_name"] != name:
                continue
            # 2.1: Evaluate the method associated to a specific command.
            # Validation is done on the CLI side.
            response = handler(**command_query["command_arguments"])
            app.commands_responses_queue.put({"response": response, "id": command_query["id"]})

View File

@ -15,7 +15,7 @@ from packaging.version import Version
from lightning_app import _logger, _PROJECT_ROOT, _root_logger
from lightning_app.__version__ import version
from lightning_app.core.constants import PREPARE_LIGHTING
from lightning_app.core.constants import PACKAGE_LIGHTNING
from lightning_app.utilities.git import check_github_repository, get_dir_name
logger = logging.getLogger(__name__)
@ -96,11 +96,13 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
# Packaging the Lightning codebase happens only inside the `lightning` repo.
git_dir_name = get_dir_name() if check_github_repository() else None
if not PREPARE_LIGHTING and (not git_dir_name or (git_dir_name and not git_dir_name.startswith("lightning"))):
is_lightning = git_dir_name and git_dir_name == "lightning"
if (PACKAGE_LIGHTNING is None and not is_lightning) or PACKAGE_LIGHTNING == "0":
return
if not bool(int(os.getenv("SKIP_LIGHTING_WHEELS_BUILD", "0"))):
download_frontend(_PROJECT_ROOT)
_prepare_wheel(_PROJECT_ROOT)
download_frontend(_PROJECT_ROOT)
_prepare_wheel(_PROJECT_ROOT)
logger.info("Packaged Lightning with your application.")
@ -108,11 +110,12 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
tar_files = [os.path.join(root, tar_name)]
# skipping this by default
if not bool(int(os.getenv("SKIP_LIGHTING_UTILITY_WHEELS_BUILD", "1"))):
# Don't skip by default
if (PACKAGE_LIGHTNING or is_lightning) and not bool(int(os.getenv("SKIP_LIGHTING_UTILITY_WHEELS_BUILD", "0"))):
# building and copying launcher wheel if installed in editable mode
launcher_project_path = get_dist_path_if_editable_install("lightning_launcher")
if launcher_project_path:
logger.info("Packaged Lightning Launcher with your application.")
_prepare_wheel(launcher_project_path)
tar_name = _copy_tar(launcher_project_path, root)
tar_files.append(os.path.join(root, tar_name))
@ -120,6 +123,7 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
# building and copying lightning-cloud wheel if installed in editable mode
lightning_cloud_project_path = get_dist_path_if_editable_install("lightning_cloud")
if lightning_cloud_project_path:
logger.info("Packaged Lightning Cloud with your application.")
_prepare_wheel(lightning_cloud_project_path)
tar_name = _copy_tar(lightning_cloud_project_path, root)
tar_files.append(os.path.join(root, tar_name))

View File

@ -5,6 +5,7 @@ import signal
import sys
import threading
import time
import traceback
import warnings
from copy import deepcopy
from dataclasses import dataclass
@ -398,6 +399,9 @@ class WorkRunner:
)
self.delta_queue.put(ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state))))
self.work.on_exception(e)
print("########## CAPTURED EXCEPTION ###########")
print(traceback.print_exc())
print("########## CAPTURED EXCEPTION ###########")
return
# 14. Copy all artifacts to the shared storage so other Works can access them while this Work gets scaled down

View File

@ -5,7 +5,7 @@ import pytest
from click.testing import CliRunner
from lightning_cloud.openapi import Externalv1LightningappInstance
from lightning_app.cli.lightning_cli import get_app_url, login, logout, main, run
from lightning_app.cli.lightning_cli import _main, get_app_url, login, logout, run
from lightning_app.runners.runtime_type import RuntimeType
@ -37,7 +37,7 @@ def test_start_target_url(runtime_type, extra_args, lightning_cloud_url, expecte
assert get_app_url(runtime_type, *extra_args) == expected_url
@pytest.mark.parametrize("command", [main, run])
@pytest.mark.parametrize("command", [_main, run])
def test_commands(command):
runner = CliRunner()
result = runner.invoke(command)
@ -46,12 +46,12 @@ def test_commands(command):
def test_main_lightning_cli_help():
"""Validate the Lightning CLI."""
res = os.popen("python -m lightning_app --help").read()
res = os.popen("python -m lightning --help").read()
assert "login " in res
assert "logout " in res
assert "run " in res
res = os.popen("python -m lightning_app run --help").read()
res = os.popen("python -m lightning run --help").read()
assert "app " in res
# hidden run commands should not appear in the help text

View File

@ -17,10 +17,9 @@ def test_non_existing_python_script():
run_work_isolated(python_script)
assert not python_script.has_started
with pytest.raises(FileNotFoundError, match=match):
python_script = TracerPythonScript(match)
run_work_isolated(python_script)
assert not python_script.has_started
python_script = TracerPythonScript(match, raise_exception=False)
run_work_isolated(python_script)
assert python_script.has_failed
def test_simple_python_script():

View File

@ -161,10 +161,12 @@ def test_update_publish_state_and_maybe_refresh_ui():
app = AppStageTestingApp(FlowA(), debug=True)
publish_state_queue = MockQueue("publish_state_queue")
commands_metadata_queue = MockQueue("commands_metadata_queue")
commands_responses_queue = MockQueue("commands_metadata_queue")
publish_state_queue.put(app.state_with_changes)
thread = UIRefresher(publish_state_queue)
thread = UIRefresher(publish_state_queue, commands_metadata_queue, commands_responses_queue)
thread.run_once()
assert global_app_state_store.get_app_state("1234") == app.state_with_changes
@ -190,11 +192,21 @@ async def test_start_server(x_lightning_type):
publish_state_queue = InfiniteQueue("publish_state_queue")
change_state_queue = MockQueue("change_state_queue")
has_started_queue = MockQueue("has_started_queue")
commands_requests_queue = MockQueue("commands_requests_queue")
commands_responses_queue = MockQueue("commands_responses_queue")
commands_metadata_queue = MockQueue("commands_metadata_queue")
state = app.state_with_changes
publish_state_queue.put(state)
spec = extract_metadata_from_app(app)
ui_refresher = start_server(
publish_state_queue, change_state_queue, has_started_queue=has_started_queue, uvicorn_run=False, spec=spec
publish_state_queue,
change_state_queue,
commands_requests_queue,
commands_responses_queue,
commands_metadata_queue,
has_started_queue=has_started_queue,
uvicorn_run=False,
spec=spec,
)
headers = headers_for({"type": x_lightning_type})
@ -331,10 +343,16 @@ def test_start_server_started():
api_publish_state_queue = mp.Queue()
api_delta_queue = mp.Queue()
has_started_queue = mp.Queue()
commands_requests_queue = mp.Queue()
commands_responses_queue = mp.Queue()
commands_metadata_queue = mp.Queue()
kwargs = dict(
api_publish_state_queue=api_publish_state_queue,
api_delta_queue=api_delta_queue,
has_started_queue=has_started_queue,
commands_requests_queue=commands_requests_queue,
commands_responses_queue=commands_responses_queue,
commands_metadata_queue=commands_metadata_queue,
port=1111,
)
@ -354,12 +372,18 @@ def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatc
api_publish_state_queue = MockQueue()
api_delta_queue = MockQueue()
has_started_queue = MockQueue()
commands_requests_queue = MockQueue()
commands_responses_queue = MockQueue()
commands_metadata_queue = MockQueue()
kwargs = dict(
host=host,
port=1111,
api_publish_state_queue=api_publish_state_queue,
api_delta_queue=api_delta_queue,
has_started_queue=has_started_queue,
commands_requests_queue=commands_requests_queue,
commands_responses_queue=commands_responses_queue,
commands_metadata_queue=commands_metadata_queue,
)
monkeypatch.setattr(api, "logger", logging.getLogger())

View File

@ -39,10 +39,11 @@ def test_file_uploader():
@mock.patch("lightning_app.source_code.uploader.requests.Session", MockedRequestSession)
def test_file_uploader_failing_when_no_etag():
response["response"] = MagicMock(headers={})
presigned_url = "https://test-url"
file_uploader = uploader.FileUploader(
presigned_url="https://test-url", source_file="test.txt", total_size=100, name="test.txt"
presigned_url=presigned_url, source_file="test.txt", total_size=100, name="test.txt"
)
file_uploader.progress = MagicMock()
with pytest.raises(ValueError, match="Unexpected response from S3, response"):
with pytest.raises(ValueError, match=f"Unexpected response from {presigned_url}, response"):
file_uploader.upload()

View File

@ -0,0 +1,162 @@
import argparse
import sys
from multiprocessing import Process
from time import sleep
from unittest.mock import MagicMock
import pytest
import requests
from pydantic import BaseModel
from lightning import LightningFlow
from lightning_app import LightningApp
from lightning_app.cli.lightning_cli import app_command
from lightning_app.core.constants import APP_SERVER_PORT
from lightning_app.runners import MultiProcessRuntime
from lightning_app.testing.helpers import RunIf
from lightning_app.utilities.commands.base import _command_to_method_and_metadata, _download_command, ClientCommand
from lightning_app.utilities.state import AppState
class SweepConfig(BaseModel):
    # Payload for the ``sweep`` command: ``SweepCommand.run`` builds it from the parsed
    # command-line options and ``FlowCommands.sweep`` is annotated to receive it.
    sweep_name: str
    num_trials: int
class SweepCommand(ClientCommand):
    """Client command that turns ``--sweep_name`` / ``--num_trials`` into a ``SweepConfig``."""

    def run(self) -> None:
        """Parse the command line, invoke the handler and check that it answered ``True``."""
        print(sys.argv)
        cli = argparse.ArgumentParser()
        for flag, kind in (("--sweep_name", str), ("--num_trials", int)):
            cli.add_argument(flag, type=kind)
        options = cli.parse_args()
        payload = SweepConfig(sweep_name=options.sweep_name, num_trials=options.num_trials)
        # Keep the call outside of the assert so it still runs when asserts are disabled.
        outcome = self.invoke_handler(config=payload)
        assert outcome is True
class FlowCommands(LightningFlow):
    """Flow exposing one plain command (``user_command``) and one client command (``sweep``)."""

    def __init__(self):
        super().__init__()
        self.names = []
        self.has_sweep = False

    def run(self):
        """Exit, after a two second pause, once ``sweep`` ran and exactly one name is stored."""
        finished = self.has_sweep and len(self.names) == 1
        if not finished:
            return
        sleep(2)
        self._exit()

    def trigger_method(self, name: str):
        """Handler of ``user_command``: record the received name."""
        self.names.append(name)

    def sweep(self, config: SweepConfig):
        """Handler of ``sweep``: flag that a sweep was requested and acknowledge with ``True``."""
        self.has_sweep = True
        return True

    def configure_commands(self):
        """Return the command mappings exposed by this flow."""
        plain_command = {"user_command": self.trigger_method}
        client_command = {"sweep": SweepCommand(self.sweep)}
        return [plain_command, client_command]
class DummyConfig(BaseModel):
    # Payload built by ``DummyCommand.run`` from its two arguments; also the annotation of
    # the module-level ``run`` handler.
    something: str
    something_else: int
class DummyCommand(ClientCommand):
    """Client command that forwards its two arguments to the handler as a ``DummyConfig``."""

    def run(self, something: str, something_else: int) -> None:
        """Invoke the handler and check that the response is ``{"body": 0}``."""
        payload = DummyConfig(something=something, something_else=something_else)
        outcome = self.invoke_handler(config=payload)
        assert outcome == {"body": 0}
def run(config: DummyConfig):
    """Handler wrapped by ``DummyCommand`` in ``test_client_commands``.

    Args:
        config: The payload sent by the client command.
    """
    # The handler receives the payload model, not the command that sent it.
    assert isinstance(config, DummyConfig)
def run_failure_0(name: str):
    """No-op handler annotated with a builtin type.

    ``test_command_to_method_and_metadata`` expects it to be rejected with
    "The provided annotation for the argument name".
    """
def run_failure_1(name):
    """No-op handler without any annotation.

    ``test_command_to_method_and_metadata`` expects it to be rejected with a message asking
    to "annotate your method".
    """
class CustomModel(BaseModel):
    # Empty model used only as the annotation of ``run_failure_2``.
    # NOTE(review): presumably rejected because it is declared in this file rather than next
    # to the command class - confirm against ``_command_to_method_and_metadata``.
    pass
def run_failure_2(name: CustomModel):
    """No-op handler annotated with ``CustomModel``.

    ``test_command_to_method_and_metadata`` expects it to be rejected with a message that
    mentions ``lightning_app/utilities/commands/base.py``.
    """
@RunIf(skip_windows=True)
def test_command_to_method_and_metadata():
    """Handlers with unsupported signatures are rejected with an explanatory error."""
    rejected_handlers = (
        (run_failure_0, "The provided annotation for the argument name"),
        (run_failure_1, "annotate your method"),
        (run_failure_2, "lightning_app/utilities/commands/base.py"),
    )
    for handler, expected_message in rejected_handlers:
        with pytest.raises(Exception, match=expected_message):
            _command_to_method_and_metadata(ClientCommand(handler))
def test_client_commands(monkeypatch):
    """Download, set up and run a client command against a mocked ``requests.post``."""
    import requests

    # Fake HTTP layer: every POST answers 200 with the JSON body ``DummyCommand.run`` expects.
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json = MagicMock(return_value={"body": 0})
    monkeypatch.setattr(requests, "post", MagicMock(return_value=fake_response))

    dummy = DummyCommand(run)
    _, metadata = _command_to_method_and_metadata(dummy)
    metadata.update(
        {
            "command": "dummy",
            "affiliation": "root",
            "is_client_command": True,
            "owner": "root",
        }
    )

    client_command, models = _download_command(metadata, None)
    client_command._setup(metadata=metadata, models=models, app_url="http//")
    client_command.run(something="1", something_else="1")
def target():
    """Process entry point: run a ``FlowCommands`` app with the multi-process runtime."""
    MultiProcessRuntime(LightningApp(FlowCommands())).dispatch()
def test_configure_commands(monkeypatch):
    """Run the app in a child process and drive both commands through ``app_command``.

    The plain command must end up in the app state, and the client command must make the
    app exit with code 0 within the timeout.
    """
    process = Process(target=target)
    process.start()
    try:
        # Wait (up to ~15s) for the REST API to answer.
        time_left = 15
        while time_left > 0:
            try:
                requests.get(f"http://localhost:{APP_SERVER_PORT}/healthz")
                break
            except requests.exceptions.ConnectionError:
                sleep(0.1)
                time_left -= 0.1

        sleep(0.5)
        monkeypatch.setattr(sys, "argv", ["lightning", "user_command", "--name=something"])
        app_command()
        sleep(0.5)
        state = AppState()
        state._request_state()
        assert state.names == ["something"]

        monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name", "my_name", "--num_trials", "1"])
        app_command()

        # Stop waiting as soon as the app exits, and give up once the timeout is spent.
        # With ``or`` the loop always burned the full 15s and never ended if the app hung.
        time_left = 15
        while time_left > 0 and process.exitcode is None:
            sleep(0.1)
            time_left -= 0.1
        assert process.exitcode == 0
    finally:
        # Never leave the app process behind, whatever the outcome of the test.
        if process.is_alive():
            process.terminate()
        process.join(timeout=5)

View File

@ -15,7 +15,7 @@ from lightning_app.utilities.state import AppState
def test_app_state_not_connected(_):
"""Test an error message when a disconnected AppState tries to access attributes."""
state = AppState()
state = AppState(port=8000)
with pytest.raises(AttributeError, match="Failed to connect and fetch the app state"):
_ = state.value
with pytest.raises(AttributeError, match="Failed to connect and fetch the app state"):
@ -209,7 +209,7 @@ def test_attach_plugin():
@mock.patch("lightning_app.utilities.state._configure_session", return_value=requests)
def test_app_state_connection_error(_):
"""Test an error message when a connection to retrieve the state can't be established."""
app_state = AppState()
app_state = AppState(port=8000)
with pytest.raises(AttributeError, match=r"Failed to connect and fetch the app state\. Is the app running?"):
app_state._request_state()

View File

@ -0,0 +1,31 @@
import os
from subprocess import Popen
from time import sleep
from unittest import mock
import pytest
from tests_app import _PROJECT_ROOT
from lightning_app.testing.testing import run_app_in_cloud
@mock.patch.dict(os.environ, {"SKIP_LIGHTING_UTILITY_WHEELS_BUILD": "0"})
@pytest.mark.cloud
def test_commands_example_cloud() -> None:
    """Run the ``app_commands`` example in the cloud and trigger both of its commands.

    The test passes once the logs contain the two names sent through the commands.
    """
    example_dir = os.path.join(_PROJECT_ROOT, "examples/app_commands")
    with run_app_in_cloud(example_dir) as (admin_page, _, fetch_logs):
        # The app id is the last path segment of the admin page URL.
        app_id = admin_page.url.rpartition("/")[2]

        cli_calls = (
            f"lightning trigger_with_client_command --name=something --app_id {app_id}",
            f"lightning trigger_without_client_command --name=else --app_id {app_id}",
        )
        for cli_call in cli_calls:
            Popen(cli_call, shell=True).wait()

        expected = "['something', 'else']"
        found = False
        while not found:
            # Read every available log line on each pass, then pause before the next poll.
            hits = [expected in line for line in fetch_logs()]
            found = any(hits)
            sleep(1)