[App] Introduce Commands (#13602)
This commit is contained in:
parent
a8d7b4476c
commit
4c35867b61
|
@ -54,6 +54,7 @@ jobs:
|
|||
- custom_work_dependencies
|
||||
- drive
|
||||
- payload
|
||||
- commands
|
||||
timeout-minutes: 35
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
@ -155,7 +156,7 @@ jobs:
|
|||
shell: bash
|
||||
run: |
|
||||
mkdir -p ${VIDEO_LOCATION}
|
||||
HEADLESS=1 python -m pytest tests/tests_app_examples/test_${{ matrix.app_name }}.py::test_${{ matrix.app_name }}_example_cloud --timeout=900 --capture=no -v --color=yes
|
||||
HEADLESS=1 PACKAGE_LIGHTNING=1 python -m pytest tests/tests_app_examples/test_${{ matrix.app_name }}.py::test_${{ matrix.app_name }}_example_cloud --timeout=900 --capture=no -v --color=yes
|
||||
# Delete the artifacts if successful
|
||||
rm -r ${VIDEO_LOCATION}/${{ matrix.app_name }}
|
||||
|
||||
|
|
|
@ -109,6 +109,7 @@ celerybeat-schedule
|
|||
|
||||
# dotenv
|
||||
.env
|
||||
.env_stagging
|
||||
|
||||
# virtualenv
|
||||
.venv
|
||||
|
@ -160,3 +161,5 @@ tags
|
|||
.tags
|
||||
src/lightning_app/ui/*
|
||||
*examples/template_react_ui*
|
||||
hars*
|
||||
artifacts/*
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
name: app-commands
|
|
@ -0,0 +1,39 @@
|
|||
from command import CustomCommand, CustomConfig
|
||||
|
||||
from lightning import LightningFlow
|
||||
from lightning_app.core.app import LightningApp
|
||||
|
||||
|
||||
class ChildFlow(LightningFlow):
|
||||
def trigger_method(self, name: str):
|
||||
print(f"Hello {name}")
|
||||
|
||||
def configure_commands(self):
|
||||
return [{"nested_trigger_command": self.trigger_method}]
|
||||
|
||||
|
||||
class FlowCommands(LightningFlow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.names = []
|
||||
self.child_flow = ChildFlow()
|
||||
|
||||
def run(self):
|
||||
if len(self.names):
|
||||
print(self.names)
|
||||
|
||||
def trigger_without_client_command(self, name: str):
|
||||
self.names.append(name)
|
||||
|
||||
def trigger_with_client_command(self, config: CustomConfig):
|
||||
self.names.append(config.name)
|
||||
|
||||
def configure_commands(self):
|
||||
commands = [
|
||||
{"trigger_without_client_command": self.trigger_without_client_command},
|
||||
{"trigger_with_client_command": CustomCommand(self.trigger_with_client_command)},
|
||||
]
|
||||
return commands + self.child_flow.configure_commands()
|
||||
|
||||
|
||||
app = LightningApp(FlowCommands())
|
|
@ -0,0 +1,17 @@
|
|||
from argparse import ArgumentParser
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lightning.app.utilities.commands import ClientCommand
|
||||
|
||||
|
||||
class CustomConfig(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class CustomCommand(ClientCommand):
|
||||
def run(self):
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--name", type=str)
|
||||
args = parser.parse_args()
|
||||
self.invoke_handler(config=CustomConfig(name=args.name))
|
|
@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602))
|
||||
|
||||
### Changed
|
||||
|
||||
- Update the Lightning App docs ([#13537](https://github.com/PyTorchLightning/pytorch-lightning/pull/13537))
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import click
|
||||
import requests
|
||||
from requests.exceptions import ConnectionError
|
||||
|
||||
from lightning_app import __version__ as ver
|
||||
|
@ -11,9 +15,13 @@ from lightning_app.cli import cmd_init, cmd_install, cmd_pl_init, cmd_react_ui_i
|
|||
from lightning_app.core.constants import get_lightning_cloud_url, LOCAL_LAUNCH_ADMIN_VIEW
|
||||
from lightning_app.runners.runtime import dispatch
|
||||
from lightning_app.runners.runtime_type import RuntimeType
|
||||
from lightning_app.utilities.cli_helpers import _format_input_env_variables
|
||||
from lightning_app.utilities.cli_helpers import (
|
||||
_format_input_env_variables,
|
||||
_retrieve_application_url_and_available_commands,
|
||||
)
|
||||
from lightning_app.utilities.install_components import register_all_external_components
|
||||
from lightning_app.utilities.login import Auth
|
||||
from lightning_app.utilities.state import headers_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -26,14 +34,23 @@ def get_app_url(runtime_type: RuntimeType, *args) -> str:
|
|||
return "http://127.0.0.1:7501/admin" if LOCAL_LAUNCH_ADMIN_VIEW else "http://127.0.0.1:7501/view"
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) == 1:
|
||||
_main()
|
||||
elif sys.argv[1] in _main.commands.keys() or sys.argv[1] == "--help":
|
||||
_main()
|
||||
else:
|
||||
app_command()
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(ver)
|
||||
def main():
|
||||
def _main():
|
||||
register_all_external_components()
|
||||
pass
|
||||
|
||||
|
||||
@main.command()
|
||||
@_main.command()
|
||||
def login():
|
||||
"""Log in to your Lightning.ai account."""
|
||||
auth = Auth()
|
||||
|
@ -46,7 +63,7 @@ def login():
|
|||
exit(1)
|
||||
|
||||
|
||||
@main.command()
|
||||
@_main.command()
|
||||
def logout():
|
||||
"""Log out of your Lightning.ai account."""
|
||||
Auth().clear()
|
||||
|
@ -93,7 +110,7 @@ def _run_app(
|
|||
click.echo("Application is ready in the cloud")
|
||||
|
||||
|
||||
@main.group()
|
||||
@_main.group()
|
||||
def run():
|
||||
"""Run your application."""
|
||||
|
||||
|
@ -125,31 +142,83 @@ def run_app(
|
|||
_run_app(file, cloud, without_server, no_cache, name, blocking, open_ui, env)
|
||||
|
||||
|
||||
@main.group(hidden=True)
|
||||
def app_command():
|
||||
"""Execute a function in a running application from its name."""
|
||||
from lightning_app.utilities.commands.base import _download_command
|
||||
|
||||
logger.warn("Lightning Commands are a beta feature and APIs aren't stable yet.")
|
||||
|
||||
debug_mode = bool(int(os.getenv("DEBUG", "0")))
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--app_id", default=None, type=str, help="Optional argument to identify an application.")
|
||||
hparams, argv = parser.parse_known_args()
|
||||
|
||||
# 1: Collect the url and comments from the running application
|
||||
url, commands = _retrieve_application_url_and_available_commands(hparams.app_id)
|
||||
if url is None or commands is None:
|
||||
raise Exception("We couldn't find any matching running app.")
|
||||
|
||||
if not commands:
|
||||
raise Exception("This application doesn't expose any commands yet.")
|
||||
|
||||
command = argv[0]
|
||||
|
||||
command_names = [c["command"] for c in commands]
|
||||
if command not in command_names:
|
||||
raise Exception(f"The provided command {command} isn't available in {command_names}")
|
||||
|
||||
# 2: Send the command from the user
|
||||
command_metadata = [c for c in commands if c["command"] == command][0]
|
||||
params = command_metadata["params"]
|
||||
|
||||
# 3: Execute the command
|
||||
if not command_metadata["is_client_command"]:
|
||||
# TODO: Improve what is supported there.
|
||||
kwargs = {k.split("=")[0].replace("--", ""): k.split("=")[1] for k in argv[1:]}
|
||||
for param in params:
|
||||
if param not in kwargs:
|
||||
raise Exception(f"The argument --{param}=X hasn't been provided.")
|
||||
json = {
|
||||
"command_name": command,
|
||||
"command_arguments": kwargs,
|
||||
"affiliation": command_metadata["affiliation"],
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
resp = requests.post(url + "/api/v1/commands", json=json, headers=headers_for({}))
|
||||
assert resp.status_code == 200, resp.json()
|
||||
else:
|
||||
client_command, models = _download_command(command_metadata, hparams.app_id, debug_mode=debug_mode)
|
||||
client_command._setup(metadata=command_metadata, models=models, app_url=url)
|
||||
sys.argv = argv
|
||||
client_command.run()
|
||||
|
||||
|
||||
@_main.group(hidden=True)
|
||||
def fork():
|
||||
"""Fork an application."""
|
||||
pass
|
||||
|
||||
|
||||
@main.group(hidden=True)
|
||||
@_main.group(hidden=True)
|
||||
def stop():
|
||||
"""Stop your application."""
|
||||
pass
|
||||
|
||||
|
||||
@main.group(hidden=True)
|
||||
@_main.group(hidden=True)
|
||||
def delete():
|
||||
"""Delete an application."""
|
||||
pass
|
||||
|
||||
|
||||
@main.group(name="list", hidden=True)
|
||||
@_main.group(name="list", hidden=True)
|
||||
def get_list():
|
||||
"""List your applications."""
|
||||
pass
|
||||
|
||||
|
||||
@main.group()
|
||||
@_main.group()
|
||||
def install():
|
||||
"""Install Lightning apps and components."""
|
||||
|
||||
|
@ -207,7 +276,7 @@ def install_component(name, yes, version):
|
|||
cmd_install.gallery_component(name, yes, version)
|
||||
|
||||
|
||||
@main.group()
|
||||
@_main.group()
|
||||
def init():
|
||||
"""Init a Lightning app and component."""
|
||||
|
||||
|
|
|
@ -93,8 +93,6 @@ class TracerPythonScript(LightningWork):
|
|||
:language: python
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
if not os.path.exists(script_path):
|
||||
raise FileNotFoundError(f"The provided `script_path` {script_path}` wasn't found.")
|
||||
self.script_path = str(script_path)
|
||||
if isinstance(script_args, str):
|
||||
script_args = script_args.split(" ")
|
||||
|
@ -105,6 +103,8 @@ class TracerPythonScript(LightningWork):
|
|||
setattr(self, name, None)
|
||||
|
||||
def run(self, **kwargs):
|
||||
if not os.path.exists(self.script_path):
|
||||
raise FileNotFoundError(f"The provided `script_path` {self.script_path}` wasn't found.")
|
||||
kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()}
|
||||
init_globals = globals()
|
||||
init_globals.update(kwargs)
|
||||
|
|
|
@ -3,6 +3,8 @@ import logging
|
|||
import os
|
||||
import queue
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from multiprocessing import Queue
|
||||
from threading import Event, Lock, Thread
|
||||
|
@ -40,6 +42,10 @@ STATE_EVENT = "State changed"
|
|||
frontend_static_dir = os.path.join(FRONTEND_DIR, "static")
|
||||
|
||||
api_app_delta_queue: Queue = None
|
||||
api_commands_requests_queue: Queue = None
|
||||
api_commands_metadata_queue: Queue = None
|
||||
api_commands_responses_queue: Queue = None
|
||||
|
||||
template = {"ui": {}, "app": {}}
|
||||
templates = Jinja2Templates(directory=FRONTEND_DIR)
|
||||
|
||||
|
@ -50,6 +56,8 @@ global_app_state_store.add(TEST_SESSION_UUID)
|
|||
lock = Lock()
|
||||
|
||||
app_spec: Optional[List] = None
|
||||
app_commands_metadata: Optional[Dict] = None
|
||||
commands_response_store = {}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -59,16 +67,22 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class UIRefresher(Thread):
|
||||
def __init__(self, api_publish_state_queue) -> None:
|
||||
def __init__(self, api_publish_state_queue, api_commands_metadata_queue, api_commands_responses_queue) -> None:
|
||||
super().__init__(daemon=True)
|
||||
self.api_publish_state_queue = api_publish_state_queue
|
||||
self.api_commands_metadata_queue = api_commands_metadata_queue
|
||||
self.api_commands_responses_queue = api_commands_responses_queue
|
||||
self._exit_event = Event()
|
||||
|
||||
def run(self):
|
||||
# TODO: Create multiple threads to handle the background logic
|
||||
# TODO: Investigate the use of `parallel=True`
|
||||
while not self._exit_event.is_set():
|
||||
self.run_once()
|
||||
try:
|
||||
while not self._exit_event.is_set():
|
||||
self.run_once()
|
||||
except Exception as e:
|
||||
logger.error(traceback.print_exc())
|
||||
raise e
|
||||
|
||||
def run_once(self):
|
||||
try:
|
||||
|
@ -78,6 +92,22 @@ class UIRefresher(Thread):
|
|||
except queue.Empty:
|
||||
pass
|
||||
|
||||
try:
|
||||
metadata = self.api_commands_metadata_queue.get(timeout=0)
|
||||
with lock:
|
||||
global app_commands_metadata
|
||||
app_commands_metadata = metadata
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
try:
|
||||
response = self.api_commands_responses_queue.get(timeout=0)
|
||||
with lock:
|
||||
global commands_response_store
|
||||
commands_response_store[response["id"]] = response["response"]
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
def join(self, timeout: Optional[float] = None) -> None:
|
||||
self._exit_event.set()
|
||||
super().join(timeout)
|
||||
|
@ -146,6 +176,43 @@ async def get_spec(
|
|||
return app_spec or []
|
||||
|
||||
|
||||
@fastapi_service.post("/api/v1/commands", response_class=JSONResponse)
|
||||
async def run_remote_command(
|
||||
request: Request,
|
||||
) -> None:
|
||||
data = await request.json()
|
||||
command_name = data.get("command_name", None)
|
||||
if not command_name:
|
||||
raise Exception("The provided command name is empty.")
|
||||
command_arguments = data.get("command_arguments", None)
|
||||
if not command_arguments:
|
||||
raise Exception("The provided command metadata is empty.")
|
||||
affiliation = data.get("affiliation", None)
|
||||
if not affiliation:
|
||||
raise Exception("The provided affiliation is empty.")
|
||||
|
||||
async def fn(data):
|
||||
request_id = data["id"]
|
||||
api_commands_requests_queue.put(data)
|
||||
|
||||
t0 = time.time()
|
||||
while request_id not in commands_response_store:
|
||||
await asyncio.sleep(0.1)
|
||||
if (time.time() - t0) > 15:
|
||||
raise Exception("The response was never received.")
|
||||
|
||||
return commands_response_store[request_id]
|
||||
|
||||
return await asyncio.create_task(fn(data))
|
||||
|
||||
|
||||
@fastapi_service.get("/api/v1/commands", response_class=JSONResponse)
|
||||
async def get_commands() -> Optional[Dict]:
|
||||
global app_commands_metadata
|
||||
with lock:
|
||||
return app_commands_metadata
|
||||
|
||||
|
||||
@fastapi_service.post("/api/v1/delta")
|
||||
async def post_delta(
|
||||
request: Request,
|
||||
|
@ -279,6 +346,9 @@ class LightningUvicornServer(uvicorn.Server):
|
|||
def start_server(
|
||||
api_publish_state_queue,
|
||||
api_delta_queue,
|
||||
commands_requests_queue,
|
||||
commands_responses_queue,
|
||||
commands_metadata_queue,
|
||||
has_started_queue: Optional[Queue] = None,
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
|
@ -288,16 +358,22 @@ def start_server(
|
|||
):
|
||||
global api_app_delta_queue
|
||||
global global_app_state_store
|
||||
global api_commands_requests_queue
|
||||
global api_commands_responses_queue
|
||||
global app_spec
|
||||
|
||||
app_spec = spec
|
||||
api_app_delta_queue = api_delta_queue
|
||||
api_commands_requests_queue = commands_requests_queue
|
||||
api_commands_responses_queue = commands_responses_queue
|
||||
api_commands_metadata_queue = commands_metadata_queue
|
||||
|
||||
if app_state_store is not None:
|
||||
global_app_state_store = app_state_store
|
||||
|
||||
global_app_state_store.add(TEST_SESSION_UUID)
|
||||
|
||||
refresher = UIRefresher(api_publish_state_queue)
|
||||
refresher = UIRefresher(api_publish_state_queue, api_commands_metadata_queue, commands_responses_queue)
|
||||
refresher.setDaemon(True)
|
||||
refresher.start()
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from lightning_app.core.queues import BaseQueue, SingleProcessQueue
|
|||
from lightning_app.frontend import Frontend
|
||||
from lightning_app.storage.path import storage_root_dir
|
||||
from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef
|
||||
from lightning_app.utilities.commands.base import _populate_commands_endpoint, _process_command_requests
|
||||
from lightning_app.utilities.component import _convert_paths_after_init
|
||||
from lightning_app.utilities.enum import AppStage
|
||||
from lightning_app.utilities.exceptions import CacheMissException, ExitAppException
|
||||
|
@ -72,6 +73,9 @@ class LightningApp:
|
|||
# queues definition.
|
||||
self.delta_queue: t.Optional[BaseQueue] = None
|
||||
self.readiness_queue: t.Optional[BaseQueue] = None
|
||||
self.commands_requests_queue: t.Optional[BaseQueue] = None
|
||||
self.commands_responses_queue: t.Optional[BaseQueue] = None
|
||||
self.commands_metadata_queue: t.Optional[BaseQueue] = None
|
||||
self.api_publish_state_queue: t.Optional[BaseQueue] = None
|
||||
self.api_delta_queue: t.Optional[BaseQueue] = None
|
||||
self.error_queue: t.Optional[BaseQueue] = None
|
||||
|
@ -81,6 +85,7 @@ class LightningApp:
|
|||
self.copy_response_queues: t.Optional[t.Dict[str, BaseQueue]] = None
|
||||
self.caller_queues: t.Optional[t.Dict[str, BaseQueue]] = None
|
||||
self.work_queues: t.Optional[t.Dict[str, BaseQueue]] = None
|
||||
self.commands: t.Optional[t.List] = None
|
||||
|
||||
self.should_publish_changes_to_api = False
|
||||
self.component_affiliation = None
|
||||
|
@ -345,6 +350,8 @@ class LightningApp:
|
|||
elif self.stage == AppStage.RESTARTING:
|
||||
return self._apply_restarting()
|
||||
|
||||
_process_command_requests(self)
|
||||
|
||||
try:
|
||||
self.check_error_queue()
|
||||
t0 = time()
|
||||
|
@ -397,6 +404,8 @@ class LightningApp:
|
|||
|
||||
self._reset_run_time_monitor()
|
||||
|
||||
_populate_commands_endpoint(self)
|
||||
|
||||
while not done:
|
||||
done = self.run_once()
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ REDIS_QUEUES_READ_DEFAULT_TIMEOUT = 0.005
|
|||
REDIS_WARNING_QUEUE_SIZE = 1000
|
||||
USER_ID = os.getenv("USER_ID", "1234")
|
||||
FRONTEND_DIR = os.path.join(os.path.dirname(lightning_app.__file__), "ui")
|
||||
PREPARE_LIGHTING = bool(int(os.getenv("PREPARE_LIGHTING", "0")))
|
||||
PACKAGE_LIGHTNING = os.getenv("PACKAGE_LIGHTNING", None)
|
||||
LOCAL_LAUNCH_ADMIN_VIEW = bool(int(os.getenv("LOCAL_LAUNCH_ADMIN_VIEW", "0")))
|
||||
CLOUD_UPLOAD_WARNING = int(os.getenv("CLOUD_UPLOAD_WARNING", "2"))
|
||||
DISABLE_DEPENDENCY_CACHE = bool(int(os.getenv("DISABLE_DEPENDENCY_CACHE", "0")))
|
||||
|
|
|
@ -356,8 +356,7 @@ class LightningFlow:
|
|||
class Flow(LightningFlow):
|
||||
def run(self):
|
||||
if self.schedule("hourly"):
|
||||
# run some code once every hour.
|
||||
print("run this every hour")
|
||||
print("run some code every hour")
|
||||
|
||||
Arguments:
|
||||
cron_pattern: The cron pattern to provide. Learn more at https://crontab.guru/.
|
||||
|
@ -509,20 +508,16 @@ class LightningFlow:
|
|||
# add your streamlit code here!
|
||||
import streamlit as st
|
||||
|
||||
st.button("Hello!")
|
||||
|
||||
**Example:** Arrange the UI of my children in tabs (default UI by Lightning).
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Flow(LightningFlow):
|
||||
...
|
||||
|
||||
def configure_layout(self):
|
||||
return [
|
||||
dict(name="First Tab", content=self.child0),
|
||||
dict(name="Second Tab", content=self.child1),
|
||||
# You can include direct URLs too
|
||||
dict(name="Lightning", content="https://lightning.ai"),
|
||||
]
|
||||
|
||||
|
@ -608,3 +603,33 @@ class LightningFlow:
|
|||
yield value
|
||||
|
||||
self._calls[call_hash].update({"has_finished": True})
|
||||
|
||||
def configure_commands(self):
|
||||
"""Configure the commands of this LightningFlow.
|
||||
|
||||
Returns a list of dictionaries mapping a command name to a flow method.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Flow(LightningFlow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.names = []
|
||||
|
||||
def configure_commands(self):
|
||||
return {"my_command_name": self.my_remote_method}
|
||||
|
||||
def my_remote_method(self, name):
|
||||
self.names.append(name)
|
||||
|
||||
Once the app is running with the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
lightning run app app.py
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
lightning my_command_name --args name=my_own_name
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -36,6 +36,9 @@ ORCHESTRATOR_RESPONSE_CONSTANT = "ORCHESTRATOR_RESPONSE"
|
|||
ORCHESTRATOR_COPY_REQUEST_CONSTANT = "ORCHESTRATOR_COPY_REQUEST"
|
||||
ORCHESTRATOR_COPY_RESPONSE_CONSTANT = "ORCHESTRATOR_COPY_RESPONSE"
|
||||
WORK_QUEUE_CONSTANT = "WORK_QUEUE"
|
||||
COMMANDS_REQUESTS_QUEUE_CONSTANT = "COMMANDS_REQUESTS_QUEUE"
|
||||
COMMANDS_RESPONSES_QUEUE_CONSTANT = "COMMANDS_RESPONSES_QUEUE"
|
||||
COMMANDS_METADATA_QUEUE_CONSTANT = "COMMANDS_METADATA_QUEUE"
|
||||
|
||||
|
||||
class QueuingSystem(Enum):
|
||||
|
@ -51,6 +54,20 @@ class QueuingSystem(Enum):
|
|||
else:
|
||||
return SingleProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
||||
|
||||
def get_commands_requests_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{COMMANDS_REQUESTS_QUEUE_CONSTANT}" if queue_id else COMMANDS_REQUESTS_QUEUE_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
|
||||
def get_commands_responses_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = (
|
||||
f"{queue_id}_{COMMANDS_RESPONSES_QUEUE_CONSTANT}" if queue_id else COMMANDS_RESPONSES_QUEUE_CONSTANT
|
||||
)
|
||||
return self._get_queue(queue_name)
|
||||
|
||||
def get_commands_metadata_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{COMMANDS_METADATA_QUEUE_CONSTANT}" if queue_id else COMMANDS_METADATA_QUEUE_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
|
||||
def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{READINESS_QUEUE_CONSTANT}" if queue_id else READINESS_QUEUE_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
|
|
|
@ -82,9 +82,11 @@ class Backend(ABC):
|
|||
kw = dict(queue_id=self.queue_id)
|
||||
app.delta_queue = self.queues.get_delta_queue(**kw)
|
||||
app.readiness_queue = self.queues.get_readiness_queue(**kw)
|
||||
app.commands_requests_queue = self.queues.get_commands_requests_queue(**kw)
|
||||
app.commands_responses_queue = self.queues.get_commands_responses_queue(**kw)
|
||||
app.commands_metadata_queue = self.queues.get_commands_metadata_queue(**kw)
|
||||
app.error_queue = self.queues.get_error_queue(**kw)
|
||||
app.delta_queue = self.queues.get_delta_queue(**kw)
|
||||
app.readiness_queue = self.queues.get_readiness_queue(**kw)
|
||||
app.error_queue = self.queues.get_error_queue(**kw)
|
||||
app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw)
|
||||
app.api_delta_queue = self.queues.get_api_delta_queue(**kw)
|
||||
|
|
|
@ -66,6 +66,9 @@ class MultiProcessRuntime(Runtime):
|
|||
api_publish_state_queue=self.app.api_publish_state_queue,
|
||||
api_delta_queue=self.app.api_delta_queue,
|
||||
has_started_queue=has_started_queue,
|
||||
commands_requests_queue=self.app.commands_requests_queue,
|
||||
commands_responses_queue=self.app.commands_responses_queue,
|
||||
commands_metadata_queue=self.app.commands_metadata_queue,
|
||||
spec=extract_metadata_from_app(self.app),
|
||||
)
|
||||
server_proc = multiprocessing.Process(target=start_server, kwargs=kwargs)
|
||||
|
|
|
@ -39,16 +39,15 @@ class FileUploader:
|
|||
self.total_size = total_size
|
||||
self.name = name
|
||||
|
||||
@staticmethod
|
||||
def upload_s3_data(url: str, data: bytes, retries: int, disconnect_retry_wait_seconds: int) -> str:
|
||||
"""Send data to s3 url.
|
||||
def upload_data(self, url: str, data: bytes, retries: int, disconnect_retry_wait_seconds: int) -> str:
|
||||
"""Send data to url.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url: str
|
||||
S3 url string to send data to
|
||||
url string to send data to
|
||||
data: bytes
|
||||
Bytes of data to send to S3
|
||||
Bytes of data to send to url
|
||||
retries: int
|
||||
Amount of retries
|
||||
disconnect_retry_wait_seconds: int
|
||||
|
@ -65,16 +64,19 @@ class FileUploader:
|
|||
retries = Retry(total=10)
|
||||
with requests.Session() as s:
|
||||
s.mount("https://", HTTPAdapter(max_retries=retries))
|
||||
response = s.put(url, data=data)
|
||||
if "ETag" not in response.headers:
|
||||
raise ValueError(f"Unexpected response from S3, response: {response.content}")
|
||||
return response.headers["ETag"]
|
||||
return self._upload_data(s, url, data)
|
||||
except BrokenPipeError:
|
||||
time.sleep(disconnect_retry_wait_seconds)
|
||||
disconnect_retries -= 1
|
||||
|
||||
raise ValueError("Unable to upload file after multiple attempts")
|
||||
|
||||
def _upload_data(self, s: requests.Session, url: str, data: bytes):
|
||||
resp = s.put(url, data=data)
|
||||
if "ETag" not in resp.headers:
|
||||
raise ValueError(f"Unexpected response from {url}, response: {resp.content}")
|
||||
return resp.headers["ETag"]
|
||||
|
||||
def upload(self) -> None:
|
||||
"""Upload files from source dir into target path in S3."""
|
||||
task_id = self.progress.add_task("upload", filename=self.name, total=self.total_size)
|
||||
|
@ -82,7 +84,7 @@ class FileUploader:
|
|||
try:
|
||||
with open(self.source_file, "rb") as f:
|
||||
data = f.read()
|
||||
self.upload_s3_data(self.presigned_url, data, self.retries, self.disconnect_retry_wait_seconds)
|
||||
self.upload_data(self.presigned_url, data, self.retries, self.disconnect_retry_wait_seconds)
|
||||
self.progress.update(task_id, advance=len(data))
|
||||
finally:
|
||||
self.progress.stop()
|
||||
|
|
|
@ -179,11 +179,7 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator:
|
|||
# 5. Create chromium browser, auth to lightning_app.ai and yield the admin and view pages.
|
||||
with sync_playwright() as p:
|
||||
browser = p.chromium.launch(headless=bool(int(os.getenv("HEADLESS", "0"))))
|
||||
payload = {
|
||||
"apiKey": Config.api_key,
|
||||
"username": Config.username,
|
||||
"duration": "120000",
|
||||
}
|
||||
payload = {"apiKey": Config.api_key, "username": Config.username, "duration": "120000"}
|
||||
context = browser.new_context(
|
||||
# Eventually this will need to be deleted
|
||||
http_credentials=HttpCredentials(
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
import re
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from lightning_app.core.constants import APP_SERVER_PORT
|
||||
from lightning_app.utilities.cloud import _get_project
|
||||
from lightning_app.utilities.network import LightningClient
|
||||
|
||||
|
||||
def _format_input_env_variables(env_list: tuple) -> Dict[str, str]:
|
||||
|
@ -35,3 +41,55 @@ def _format_input_env_variables(env_list: tuple) -> Dict[str, str]:
|
|||
|
||||
env_vars_dict[var_name] = value
|
||||
return env_vars_dict
|
||||
|
||||
|
||||
def _is_url(id: Optional[str]) -> bool:
|
||||
if isinstance(id, str) and (id.startswith("https://") or id.startswith("http://")):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _retrieve_application_url_and_available_commands(app_id_or_name_or_url: Optional[str]):
|
||||
"""This function is used to retrieve the current url associated with an id."""
|
||||
|
||||
if _is_url(app_id_or_name_or_url):
|
||||
url = app_id_or_name_or_url
|
||||
assert url
|
||||
resp = requests.get(url + "/api/v1/commands")
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"The server didn't process the request properly. Found {resp.json()}")
|
||||
return url, resp.json()
|
||||
|
||||
# 2: If no identifier has been provided, evaluate the local application
|
||||
failed_locally = False
|
||||
|
||||
if app_id_or_name_or_url is None:
|
||||
try:
|
||||
url = f"http://localhost:{APP_SERVER_PORT}"
|
||||
resp = requests.get(f"{url}/api/v1/commands")
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"The server didn't process the request properly. Found {resp.json()}")
|
||||
return url, resp.json()
|
||||
except requests.exceptions.ConnectionError:
|
||||
failed_locally = True
|
||||
|
||||
# 3: If an identified was provided or the local evaluation has failed, evaluate the cloud.
|
||||
if app_id_or_name_or_url or failed_locally:
|
||||
client = LightningClient()
|
||||
project = _get_project(client)
|
||||
list_lightningapps = client.lightningapp_instance_service_list_lightningapp_instances(project.project_id)
|
||||
|
||||
lightningapp_names = [lightningapp.name for lightningapp in list_lightningapps.lightningapps]
|
||||
|
||||
if not app_id_or_name_or_url:
|
||||
raise Exception(f"Provide an application name, id or url with --app_id=X. Found {lightningapp_names}")
|
||||
|
||||
for lightningapp in list_lightningapps.lightningapps:
|
||||
if lightningapp.id == app_id_or_name_or_url or lightningapp.name == app_id_or_name_or_url:
|
||||
if lightningapp.status.url == "":
|
||||
raise Exception("The application is starting. Try in a few moments.")
|
||||
resp = requests.get(lightningapp.status.url + "/api/v1/commands")
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"The server didn't process the request properly. Found {resp.json()}")
|
||||
return lightningapp.status.url, resp.json()
|
||||
return None, None
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from lightning_app.utilities.commands.base import ClientCommand
|
||||
|
||||
__all__ = ["ClientCommand"]
|
|
@ -0,0 +1,245 @@
|
|||
import errno
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import sys
|
||||
from getpass import getuser
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from tempfile import gettempdir
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lightning_app.utilities.app_helpers import is_overridden
|
||||
from lightning_app.utilities.cloud import _get_project
|
||||
from lightning_app.utilities.network import LightningClient
|
||||
from lightning_app.utilities.state import AppState
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def makedirs(path: str):
|
||||
r"""Recursive directory creation function."""
|
||||
try:
|
||||
os.makedirs(osp.expanduser(osp.normpath(path)))
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST and osp.isdir(path):
|
||||
raise e
|
||||
|
||||
|
||||
class _ClientCommandConfig(BaseModel):
|
||||
command: str
|
||||
affiliation: str
|
||||
params: Dict[str, str]
|
||||
is_client_command: bool
|
||||
cls_path: str
|
||||
cls_name: str
|
||||
owner: str
|
||||
requirements: Optional[List[str]]
|
||||
|
||||
|
||||
class ClientCommand:
|
||||
def __init__(self, method: Callable, requirements: Optional[List[str]] = None) -> None:
|
||||
self.method = method
|
||||
flow = getattr(method, "__self__", None)
|
||||
self.owner = flow.name if flow else None
|
||||
self.requirements = requirements
|
||||
self.metadata = None
|
||||
self.models: Optional[Dict[str, BaseModel]] = None
|
||||
self.app_url = None
|
||||
self._state = None
|
||||
|
||||
def _setup(self, metadata: Dict[str, Any], models: Dict[str, BaseModel], app_url: str) -> None:
|
||||
self.metadata = metadata
|
||||
self.models = models
|
||||
self.app_url = app_url
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self._state is None:
|
||||
assert self.app_url
|
||||
# TODO: Resolve this hack
|
||||
os.environ["LIGHTNING_APP_STATE_URL"] = "1"
|
||||
self._state = AppState(host=self.app_url)
|
||||
self._state._request_state()
|
||||
os.environ.pop("LIGHTNING_APP_STATE_URL")
|
||||
return self._state
|
||||
|
||||
def run(self, **cli_kwargs) -> None:
|
||||
"""Overrides with the logic to execute on the client side."""
|
||||
|
||||
def invoke_handler(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
from lightning.app.utilities.state import headers_for
|
||||
|
||||
assert kwargs.keys() == self.models.keys()
|
||||
for k, v in kwargs.items():
|
||||
assert isinstance(v, self.models[k])
|
||||
json = {
|
||||
"command_name": self.metadata["command"],
|
||||
"command_arguments": {k: v.json() for k, v in kwargs.items()},
|
||||
"affiliation": self.metadata["affiliation"],
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
resp = requests.post(self.app_url + "/api/v1/commands", json=json, headers=headers_for({}))
|
||||
assert resp.status_code == 200, resp.json()
|
||||
return resp.json()
|
||||
|
||||
def _to_dict(self):
|
||||
return {"owner": self.owner, "requirements": self.requirements}
|
||||
|
||||
def __call__(self, **kwargs: Any) -> Any:
|
||||
assert self.models
|
||||
input = {}
|
||||
for k, v in kwargs.items():
|
||||
input[k] = self.models[k].parse_raw(v)
|
||||
return self.method(**input)
|
||||
|
||||
|
||||
def _download_command(
|
||||
command_metadata: Dict[str, Any],
|
||||
app_id: Optional[str],
|
||||
debug_mode: bool = False,
|
||||
) -> Tuple[ClientCommand, Dict[str, BaseModel]]:
|
||||
# TODO: This is a skateboard implementation and the final version will rely on versioned
|
||||
# immutable commands for security concerns
|
||||
config = _ClientCommandConfig(**command_metadata)
|
||||
tmpdir = osp.join(gettempdir(), f"{getuser()}_commands")
|
||||
makedirs(tmpdir)
|
||||
target_file = osp.join(tmpdir, f"{config.command}.py")
|
||||
if app_id:
|
||||
client = LightningClient()
|
||||
project_id = _get_project(client).project_id
|
||||
response = client.lightningapp_instance_service_list_lightningapp_instance_artifacts(project_id, app_id)
|
||||
for artifact in response.artifacts:
|
||||
if f"commands/{config.command}.py" == artifact.filename:
|
||||
r = requests.get(artifact.url, allow_redirects=True)
|
||||
with open(target_file, "wb") as f:
|
||||
f.write(r.content)
|
||||
else:
|
||||
if not debug_mode:
|
||||
shutil.copy(config.cls_path, target_file)
|
||||
|
||||
cls_name = config.cls_name
|
||||
spec = spec_from_file_location(config.cls_name, config.cls_path if debug_mode else target_file)
|
||||
mod = module_from_spec(spec)
|
||||
sys.modules[cls_name] = mod
|
||||
spec.loader.exec_module(mod)
|
||||
command = getattr(mod, cls_name)(method=None, requirements=config.requirements)
|
||||
models = {k: getattr(mod, v) for k, v in config.params.items()}
|
||||
if debug_mode:
|
||||
shutil.rmtree(tmpdir)
|
||||
return command, models
|
||||
|
||||
|
||||
def _to_annotation(anno: str) -> str:
|
||||
anno = anno.split("'")[1]
|
||||
if "." in anno:
|
||||
return anno.split(".")[-1]
|
||||
return anno
|
||||
|
||||
|
||||
def _command_to_method_and_metadata(command: ClientCommand) -> Tuple[Callable, Dict[str, Any]]:
|
||||
"""Extract method and its metadata from a ClientCommand."""
|
||||
params = inspect.signature(command.method).parameters
|
||||
command_metadata = {
|
||||
"cls_path": inspect.getfile(command.__class__),
|
||||
"cls_name": command.__class__.__name__,
|
||||
"params": {p.name: _to_annotation(str(p.annotation)) for p in params.values()},
|
||||
**command._to_dict(),
|
||||
}
|
||||
method = command.method
|
||||
command.models = {}
|
||||
for k, v in command_metadata["params"].items():
|
||||
if v == "_empty":
|
||||
raise Exception(
|
||||
f"Please, annotate your method {method} with pydantic BaseModel. Refer to the documentation."
|
||||
)
|
||||
config = getattr(sys.modules[command.__module__], v, None)
|
||||
if config is None:
|
||||
config = getattr(sys.modules[method.__module__], v, None)
|
||||
if config:
|
||||
raise Exception(
|
||||
f"The provided annotation for the argument {k} should in the file "
|
||||
f"{inspect.getfile(command.__class__)}, not {inspect.getfile(command.method)}."
|
||||
)
|
||||
if config is None or not issubclass(config, BaseModel):
|
||||
raise Exception(
|
||||
f"The provided annotation for the argument {k} shouldn't an instance of pydantic BaseModel."
|
||||
)
|
||||
command.models[k] = config
|
||||
return method, command_metadata
|
||||
|
||||
|
||||
def _upload_command(command_name: str, command: ClientCommand) -> Optional[str]:
|
||||
from lightning_app.storage.path import _is_s3fs_available, filesystem, shared_storage_path
|
||||
|
||||
filepath = f"commands/{command_name}.py"
|
||||
remote_url = str(shared_storage_path() / "artifacts" / filepath)
|
||||
fs = filesystem()
|
||||
|
||||
if _is_s3fs_available():
|
||||
from s3fs import S3FileSystem
|
||||
|
||||
if not isinstance(fs, S3FileSystem):
|
||||
return
|
||||
source_file = str(inspect.getfile(command.__class__))
|
||||
remote_url = str(shared_storage_path() / "artifacts" / filepath)
|
||||
fs.put(source_file, remote_url)
|
||||
return filepath
|
||||
|
||||
|
||||
def _populate_commands_endpoint(app):
|
||||
if not is_overridden("configure_commands", app.root):
|
||||
return
|
||||
|
||||
# 1: Populate commands metadata
|
||||
commands = app.root.configure_commands()
|
||||
commands_metadata = []
|
||||
command_names = set()
|
||||
for command_mapping in commands:
|
||||
for command_name, command in command_mapping.items():
|
||||
is_client_command = isinstance(command, ClientCommand)
|
||||
extras = {}
|
||||
if is_client_command:
|
||||
_upload_command(command_name, command)
|
||||
command, extras = _command_to_method_and_metadata(command)
|
||||
if command_name in command_names:
|
||||
raise Exception(f"The component name {command_name} has already been used. They need to be unique.")
|
||||
command_names.add(command_name)
|
||||
params = inspect.signature(command).parameters
|
||||
commands_metadata.append(
|
||||
{
|
||||
"command": command_name,
|
||||
"affiliation": command.__self__.name,
|
||||
"params": list(params.keys()),
|
||||
"is_client_command": is_client_command,
|
||||
**extras,
|
||||
}
|
||||
)
|
||||
|
||||
# 1.2: Pass the collected commands through the queue to the Rest API.
|
||||
app.commands_metadata_queue.put(commands_metadata)
|
||||
app.commands = commands
|
||||
|
||||
|
||||
def _process_command_requests(app):
|
||||
if not is_overridden("configure_commands", app.root):
|
||||
return
|
||||
|
||||
# 1: Populate commands metadata
|
||||
commands = app.commands
|
||||
|
||||
# 2: Collect requests metadata
|
||||
command_query = app.get_state_changed_from_queue(app.commands_requests_queue)
|
||||
if command_query:
|
||||
for command in commands:
|
||||
for command_name, method in command.items():
|
||||
if command_query["command_name"] == command_name:
|
||||
# 2.1: Evaluate the method associated to a specific command.
|
||||
# Validation is done on the CLI side.
|
||||
response = method(**command_query["command_arguments"])
|
||||
app.commands_responses_queue.put({"response": response, "id": command_query["id"]})
|
|
@ -15,7 +15,7 @@ from packaging.version import Version
|
|||
|
||||
from lightning_app import _logger, _PROJECT_ROOT, _root_logger
|
||||
from lightning_app.__version__ import version
|
||||
from lightning_app.core.constants import PREPARE_LIGHTING
|
||||
from lightning_app.core.constants import PACKAGE_LIGHTNING
|
||||
from lightning_app.utilities.git import check_github_repository, get_dir_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -96,11 +96,13 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
|
|||
# Packaging the Lightning codebase happens only inside the `lightning` repo.
|
||||
git_dir_name = get_dir_name() if check_github_repository() else None
|
||||
|
||||
if not PREPARE_LIGHTING and (not git_dir_name or (git_dir_name and not git_dir_name.startswith("lightning"))):
|
||||
is_lightning = git_dir_name and git_dir_name == "lightning"
|
||||
|
||||
if (PACKAGE_LIGHTNING is None and not is_lightning) or PACKAGE_LIGHTNING == "0":
|
||||
return
|
||||
if not bool(int(os.getenv("SKIP_LIGHTING_WHEELS_BUILD", "0"))):
|
||||
download_frontend(_PROJECT_ROOT)
|
||||
_prepare_wheel(_PROJECT_ROOT)
|
||||
|
||||
download_frontend(_PROJECT_ROOT)
|
||||
_prepare_wheel(_PROJECT_ROOT)
|
||||
|
||||
logger.info("Packaged Lightning with your application.")
|
||||
|
||||
|
@ -108,11 +110,12 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
|
|||
|
||||
tar_files = [os.path.join(root, tar_name)]
|
||||
|
||||
# skipping this by default
|
||||
if not bool(int(os.getenv("SKIP_LIGHTING_UTILITY_WHEELS_BUILD", "1"))):
|
||||
# Don't skip by default
|
||||
if (PACKAGE_LIGHTNING or is_lightning) and not bool(int(os.getenv("SKIP_LIGHTING_UTILITY_WHEELS_BUILD", "0"))):
|
||||
# building and copying launcher wheel if installed in editable mode
|
||||
launcher_project_path = get_dist_path_if_editable_install("lightning_launcher")
|
||||
if launcher_project_path:
|
||||
logger.info("Packaged Lightning Launcher with your application.")
|
||||
_prepare_wheel(launcher_project_path)
|
||||
tar_name = _copy_tar(launcher_project_path, root)
|
||||
tar_files.append(os.path.join(root, tar_name))
|
||||
|
@ -120,6 +123,7 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
|
|||
# building and copying lightning-cloud wheel if installed in editable mode
|
||||
lightning_cloud_project_path = get_dist_path_if_editable_install("lightning_cloud")
|
||||
if lightning_cloud_project_path:
|
||||
logger.info("Packaged Lightning Cloud with your application.")
|
||||
_prepare_wheel(lightning_cloud_project_path)
|
||||
tar_name = _copy_tar(lightning_cloud_project_path, root)
|
||||
tar_files.append(os.path.join(root, tar_name))
|
||||
|
|
|
@ -5,6 +5,7 @@ import signal
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
|
@ -398,6 +399,9 @@ class WorkRunner:
|
|||
)
|
||||
self.delta_queue.put(ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state))))
|
||||
self.work.on_exception(e)
|
||||
print("########## CAPTURED EXCEPTION ###########")
|
||||
print(traceback.print_exc())
|
||||
print("########## CAPTURED EXCEPTION ###########")
|
||||
return
|
||||
|
||||
# 14. Copy all artifacts to the shared storage so other Works can access them while this Work gets scaled down
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from click.testing import CliRunner
|
||||
from lightning_cloud.openapi import Externalv1LightningappInstance
|
||||
|
||||
from lightning_app.cli.lightning_cli import get_app_url, login, logout, main, run
|
||||
from lightning_app.cli.lightning_cli import _main, get_app_url, login, logout, run
|
||||
from lightning_app.runners.runtime_type import RuntimeType
|
||||
|
||||
|
||||
|
@ -37,7 +37,7 @@ def test_start_target_url(runtime_type, extra_args, lightning_cloud_url, expecte
|
|||
assert get_app_url(runtime_type, *extra_args) == expected_url
|
||||
|
||||
|
||||
@pytest.mark.parametrize("command", [main, run])
|
||||
@pytest.mark.parametrize("command", [_main, run])
|
||||
def test_commands(command):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(command)
|
||||
|
@ -46,12 +46,12 @@ def test_commands(command):
|
|||
|
||||
def test_main_lightning_cli_help():
|
||||
"""Validate the Lightning CLI."""
|
||||
res = os.popen("python -m lightning_app --help").read()
|
||||
res = os.popen("python -m lightning --help").read()
|
||||
assert "login " in res
|
||||
assert "logout " in res
|
||||
assert "run " in res
|
||||
|
||||
res = os.popen("python -m lightning_app run --help").read()
|
||||
res = os.popen("python -m lightning run --help").read()
|
||||
assert "app " in res
|
||||
|
||||
# hidden run commands should not appear in the help text
|
||||
|
|
|
@ -17,10 +17,9 @@ def test_non_existing_python_script():
|
|||
run_work_isolated(python_script)
|
||||
assert not python_script.has_started
|
||||
|
||||
with pytest.raises(FileNotFoundError, match=match):
|
||||
python_script = TracerPythonScript(match)
|
||||
run_work_isolated(python_script)
|
||||
assert not python_script.has_started
|
||||
python_script = TracerPythonScript(match, raise_exception=False)
|
||||
run_work_isolated(python_script)
|
||||
assert python_script.has_failed
|
||||
|
||||
|
||||
def test_simple_python_script():
|
||||
|
|
|
@ -161,10 +161,12 @@ def test_update_publish_state_and_maybe_refresh_ui():
|
|||
|
||||
app = AppStageTestingApp(FlowA(), debug=True)
|
||||
publish_state_queue = MockQueue("publish_state_queue")
|
||||
commands_metadata_queue = MockQueue("commands_metadata_queue")
|
||||
commands_responses_queue = MockQueue("commands_metadata_queue")
|
||||
|
||||
publish_state_queue.put(app.state_with_changes)
|
||||
|
||||
thread = UIRefresher(publish_state_queue)
|
||||
thread = UIRefresher(publish_state_queue, commands_metadata_queue, commands_responses_queue)
|
||||
thread.run_once()
|
||||
|
||||
assert global_app_state_store.get_app_state("1234") == app.state_with_changes
|
||||
|
@ -190,11 +192,21 @@ async def test_start_server(x_lightning_type):
|
|||
publish_state_queue = InfiniteQueue("publish_state_queue")
|
||||
change_state_queue = MockQueue("change_state_queue")
|
||||
has_started_queue = MockQueue("has_started_queue")
|
||||
commands_requests_queue = MockQueue("commands_requests_queue")
|
||||
commands_responses_queue = MockQueue("commands_responses_queue")
|
||||
commands_metadata_queue = MockQueue("commands_metadata_queue")
|
||||
state = app.state_with_changes
|
||||
publish_state_queue.put(state)
|
||||
spec = extract_metadata_from_app(app)
|
||||
ui_refresher = start_server(
|
||||
publish_state_queue, change_state_queue, has_started_queue=has_started_queue, uvicorn_run=False, spec=spec
|
||||
publish_state_queue,
|
||||
change_state_queue,
|
||||
commands_requests_queue,
|
||||
commands_responses_queue,
|
||||
commands_metadata_queue,
|
||||
has_started_queue=has_started_queue,
|
||||
uvicorn_run=False,
|
||||
spec=spec,
|
||||
)
|
||||
headers = headers_for({"type": x_lightning_type})
|
||||
|
||||
|
@ -331,10 +343,16 @@ def test_start_server_started():
|
|||
api_publish_state_queue = mp.Queue()
|
||||
api_delta_queue = mp.Queue()
|
||||
has_started_queue = mp.Queue()
|
||||
commands_requests_queue = mp.Queue()
|
||||
commands_responses_queue = mp.Queue()
|
||||
commands_metadata_queue = mp.Queue()
|
||||
kwargs = dict(
|
||||
api_publish_state_queue=api_publish_state_queue,
|
||||
api_delta_queue=api_delta_queue,
|
||||
has_started_queue=has_started_queue,
|
||||
commands_requests_queue=commands_requests_queue,
|
||||
commands_responses_queue=commands_responses_queue,
|
||||
commands_metadata_queue=commands_metadata_queue,
|
||||
port=1111,
|
||||
)
|
||||
|
||||
|
@ -354,12 +372,18 @@ def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatc
|
|||
api_publish_state_queue = MockQueue()
|
||||
api_delta_queue = MockQueue()
|
||||
has_started_queue = MockQueue()
|
||||
commands_requests_queue = MockQueue()
|
||||
commands_responses_queue = MockQueue()
|
||||
commands_metadata_queue = MockQueue()
|
||||
kwargs = dict(
|
||||
host=host,
|
||||
port=1111,
|
||||
api_publish_state_queue=api_publish_state_queue,
|
||||
api_delta_queue=api_delta_queue,
|
||||
has_started_queue=has_started_queue,
|
||||
commands_requests_queue=commands_requests_queue,
|
||||
commands_responses_queue=commands_responses_queue,
|
||||
commands_metadata_queue=commands_metadata_queue,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(api, "logger", logging.getLogger())
|
||||
|
|
|
@ -39,10 +39,11 @@ def test_file_uploader():
|
|||
@mock.patch("lightning_app.source_code.uploader.requests.Session", MockedRequestSession)
|
||||
def test_file_uploader_failing_when_no_etag():
|
||||
response["response"] = MagicMock(headers={})
|
||||
presigned_url = "https://test-url"
|
||||
file_uploader = uploader.FileUploader(
|
||||
presigned_url="https://test-url", source_file="test.txt", total_size=100, name="test.txt"
|
||||
presigned_url=presigned_url, source_file="test.txt", total_size=100, name="test.txt"
|
||||
)
|
||||
file_uploader.progress = MagicMock()
|
||||
|
||||
with pytest.raises(ValueError, match="Unexpected response from S3, response"):
|
||||
with pytest.raises(ValueError, match=f"Unexpected response from {presigned_url}, response"):
|
||||
file_uploader.upload()
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
import argparse
|
||||
import sys
|
||||
from multiprocessing import Process
|
||||
from time import sleep
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lightning import LightningFlow
|
||||
from lightning_app import LightningApp
|
||||
from lightning_app.cli.lightning_cli import app_command
|
||||
from lightning_app.core.constants import APP_SERVER_PORT
|
||||
from lightning_app.runners import MultiProcessRuntime
|
||||
from lightning_app.testing.helpers import RunIf
|
||||
from lightning_app.utilities.commands.base import _command_to_method_and_metadata, _download_command, ClientCommand
|
||||
from lightning_app.utilities.state import AppState
|
||||
|
||||
|
||||
class SweepConfig(BaseModel):
|
||||
sweep_name: str
|
||||
num_trials: int
|
||||
|
||||
|
||||
class SweepCommand(ClientCommand):
|
||||
def run(self) -> None:
|
||||
print(sys.argv)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--sweep_name", type=str)
|
||||
parser.add_argument("--num_trials", type=int)
|
||||
hparams = parser.parse_args()
|
||||
|
||||
config = SweepConfig(sweep_name=hparams.sweep_name, num_trials=hparams.num_trials)
|
||||
response = self.invoke_handler(config=config)
|
||||
assert response is True
|
||||
|
||||
|
||||
class FlowCommands(LightningFlow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.names = []
|
||||
self.has_sweep = False
|
||||
|
||||
def run(self):
|
||||
if self.has_sweep and len(self.names) == 1:
|
||||
sleep(2)
|
||||
self._exit()
|
||||
|
||||
def trigger_method(self, name: str):
|
||||
self.names.append(name)
|
||||
|
||||
def sweep(self, config: SweepConfig):
|
||||
self.has_sweep = True
|
||||
return True
|
||||
|
||||
def configure_commands(self):
|
||||
return [{"user_command": self.trigger_method}, {"sweep": SweepCommand(self.sweep)}]
|
||||
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
something: str
|
||||
something_else: int
|
||||
|
||||
|
||||
class DummyCommand(ClientCommand):
|
||||
def run(self, something: str, something_else: int) -> None:
|
||||
config = DummyConfig(something=something, something_else=something_else)
|
||||
response = self.invoke_handler(config=config)
|
||||
assert response == {"body": 0}
|
||||
|
||||
|
||||
def run(config: DummyConfig):
|
||||
assert isinstance(config, DummyCommand)
|
||||
|
||||
|
||||
def run_failure_0(name: str):
|
||||
pass
|
||||
|
||||
|
||||
def run_failure_1(name):
|
||||
pass
|
||||
|
||||
|
||||
class CustomModel(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
def run_failure_2(name: CustomModel):
|
||||
pass
|
||||
|
||||
|
||||
@RunIf(skip_windows=True)
|
||||
def test_command_to_method_and_metadata():
|
||||
with pytest.raises(Exception, match="The provided annotation for the argument name"):
|
||||
_command_to_method_and_metadata(ClientCommand(run_failure_0))
|
||||
|
||||
with pytest.raises(Exception, match="annotate your method"):
|
||||
_command_to_method_and_metadata(ClientCommand(run_failure_1))
|
||||
|
||||
with pytest.raises(Exception, match="lightning_app/utilities/commands/base.py"):
|
||||
_command_to_method_and_metadata(ClientCommand(run_failure_2))
|
||||
|
||||
|
||||
def test_client_commands(monkeypatch):
|
||||
import requests
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
value = {"body": 0}
|
||||
resp.json = MagicMock(return_value=value)
|
||||
post = MagicMock()
|
||||
post.return_value = resp
|
||||
monkeypatch.setattr(requests, "post", post)
|
||||
url = "http//"
|
||||
kwargs = {"something": "1", "something_else": "1"}
|
||||
command = DummyCommand(run)
|
||||
_, command_metadata = _command_to_method_and_metadata(command)
|
||||
command_metadata.update(
|
||||
{
|
||||
"command": "dummy",
|
||||
"affiliation": "root",
|
||||
"is_client_command": True,
|
||||
"owner": "root",
|
||||
}
|
||||
)
|
||||
client_command, models = _download_command(command_metadata, None)
|
||||
client_command._setup(metadata=command_metadata, models=models, app_url=url)
|
||||
client_command.run(**kwargs)
|
||||
|
||||
|
||||
def target():
|
||||
app = LightningApp(FlowCommands())
|
||||
MultiProcessRuntime(app).dispatch()
|
||||
|
||||
|
||||
def test_configure_commands(monkeypatch):
|
||||
process = Process(target=target)
|
||||
process.start()
|
||||
time_left = 15
|
||||
while time_left > 0:
|
||||
try:
|
||||
requests.get(f"http://localhost:{APP_SERVER_PORT}/healthz")
|
||||
break
|
||||
except requests.exceptions.ConnectionError:
|
||||
sleep(0.1)
|
||||
time_left -= 0.1
|
||||
|
||||
sleep(0.5)
|
||||
monkeypatch.setattr(sys, "argv", ["lightning", "user_command", "--name=something"])
|
||||
app_command()
|
||||
sleep(0.5)
|
||||
state = AppState()
|
||||
state._request_state()
|
||||
assert state.names == ["something"]
|
||||
monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name", "my_name", "--num_trials", "1"])
|
||||
app_command()
|
||||
time_left = 15
|
||||
while time_left > 0 or process.exitcode is None:
|
||||
sleep(0.1)
|
||||
time_left -= 0.1
|
||||
assert process.exitcode == 0
|
|
@ -15,7 +15,7 @@ from lightning_app.utilities.state import AppState
|
|||
def test_app_state_not_connected(_):
|
||||
|
||||
"""Test an error message when a disconnected AppState tries to access attributes."""
|
||||
state = AppState()
|
||||
state = AppState(port=8000)
|
||||
with pytest.raises(AttributeError, match="Failed to connect and fetch the app state"):
|
||||
_ = state.value
|
||||
with pytest.raises(AttributeError, match="Failed to connect and fetch the app state"):
|
||||
|
@ -209,7 +209,7 @@ def test_attach_plugin():
|
|||
@mock.patch("lightning_app.utilities.state._configure_session", return_value=requests)
|
||||
def test_app_state_connection_error(_):
|
||||
"""Test an error message when a connection to retrieve the state can't be established."""
|
||||
app_state = AppState()
|
||||
app_state = AppState(port=8000)
|
||||
with pytest.raises(AttributeError, match=r"Failed to connect and fetch the app state\. Is the app running?"):
|
||||
app_state._request_state()
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import os
|
||||
from subprocess import Popen
|
||||
from time import sleep
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from tests_app import _PROJECT_ROOT
|
||||
|
||||
from lightning_app.testing.testing import run_app_in_cloud
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"SKIP_LIGHTING_UTILITY_WHEELS_BUILD": "0"})
|
||||
@pytest.mark.cloud
|
||||
def test_commands_example_cloud() -> None:
|
||||
with run_app_in_cloud(os.path.join(_PROJECT_ROOT, "examples/app_commands")) as (
|
||||
admin_page,
|
||||
_,
|
||||
fetch_logs,
|
||||
):
|
||||
app_id = admin_page.url.split("/")[-1]
|
||||
cmd = f"lightning trigger_with_client_command --name=something --app_id {app_id}"
|
||||
Popen(cmd, shell=True).wait()
|
||||
cmd = f"lightning trigger_without_client_command --name=else --app_id {app_id}"
|
||||
Popen(cmd, shell=True).wait()
|
||||
|
||||
has_logs = False
|
||||
while not has_logs:
|
||||
for log in fetch_logs():
|
||||
if "['something', 'else']" in log:
|
||||
has_logs = True
|
||||
sleep(1)
|
Loading…
Reference in New Issue