[App] Add support for plugins to return actions (#16832)
This commit is contained in:
parent
62e3d5854f
commit
beced48904
|
@ -1,4 +1,4 @@
|
|||
lightning-cloud>=0.5.26
|
||||
lightning-cloud>=0.5.27
|
||||
packaging
|
||||
typing-extensions>=4.0.0, <=4.4.0
|
||||
deepdiff>=5.7.0, <6.2.4
|
||||
|
|
|
@ -33,9 +33,9 @@ if "__version__" not in locals():
|
|||
|
||||
from lightning.app.core.app import LightningApp # noqa: E402
|
||||
from lightning.app.core.flow import LightningFlow # noqa: E402
|
||||
from lightning.app.core.plugin import LightningPlugin # noqa: E402
|
||||
from lightning.app.core.work import LightningWork # noqa: E402
|
||||
from lightning.app.perf import pdb # noqa: E402
|
||||
from lightning.app.plugin.plugin import LightningPlugin # noqa: E402
|
||||
from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402
|
||||
from lightning.app.utilities.packaging.cloud_compute import CloudCompute # noqa: E402
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from lightning.app.core.app import LightningApp
|
||||
from lightning.app.core.flow import LightningFlow
|
||||
from lightning.app.core.plugin import LightningPlugin
|
||||
from lightning.app.core.work import LightningWork
|
||||
|
||||
__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin"]
|
||||
__all__ = ["LightningApp", "LightningFlow", "LightningWork"]
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
import lightning.app.plugin.actions as actions
|
||||
from lightning.app.plugin.actions import NavigateTo, Toast, ToastSeverity
|
||||
from lightning.app.plugin.plugin import LightningPlugin
|
||||
|
||||
__all__ = ["LightningPlugin", "actions", "Toast", "ToastSeverity", "NavigateTo"]
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright The Lightning AI team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
from lightning_cloud.openapi.models import V1CloudSpaceAppAction, V1CloudSpaceAppActionType
|
||||
|
||||
|
||||
class _Action:
|
||||
"""Actions are returned by `LightningPlugin` objects to perform actions in the UI."""
|
||||
|
||||
def to_spec(self) -> V1CloudSpaceAppAction:
|
||||
"""Convert this action to a ``V1CloudSpaceAppAction``"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class NavigateTo(_Action):
|
||||
"""The ``NavigateTo`` action can be used to navigate to a relative URL within the Lightning frontend.
|
||||
|
||||
Args:
|
||||
url: The relative URL to navigate to. E.g. ``/<username>/<project>``.
|
||||
"""
|
||||
|
||||
url: str
|
||||
|
||||
def to_spec(self) -> V1CloudSpaceAppAction:
|
||||
return V1CloudSpaceAppAction(
|
||||
type=V1CloudSpaceAppActionType.NAVIGATE_TO,
|
||||
content=self.url,
|
||||
)
|
||||
|
||||
|
||||
class ToastSeverity(Enum):
|
||||
ERROR = "error"
|
||||
INFO = "info"
|
||||
SUCCESS = "success"
|
||||
WARNING = "warning"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Toast(_Action):
|
||||
"""The ``Toast`` action can be used to display a toast message to the user.
|
||||
|
||||
Args:
|
||||
severity: The severity level of the toast. One of: "error", "info", "success", "warning".
|
||||
message: The message body.
|
||||
"""
|
||||
|
||||
severity: Union[ToastSeverity, str]
|
||||
message: str
|
||||
|
||||
def to_spec(self) -> V1CloudSpaceAppAction:
|
||||
return V1CloudSpaceAppAction(
|
||||
type=V1CloudSpaceAppActionType.TOAST,
|
||||
content=f"{self.severity}:{self.message}",
|
||||
)
|
|
@ -15,7 +15,7 @@ import os
|
|||
import tarfile
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
@ -24,6 +24,8 @@ from fastapi import FastAPI, HTTPException, status
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lightning.app.core import constants
|
||||
from lightning.app.plugin.actions import _Action
|
||||
from lightning.app.utilities.app_helpers import Logger
|
||||
from lightning.app.utilities.component import _set_flow_context
|
||||
from lightning.app.utilities.enum import AppStage
|
||||
|
@ -41,16 +43,20 @@ class LightningPlugin:
|
|||
self.cloudspace_id = ""
|
||||
self.cluster_id = ""
|
||||
|
||||
def run(self, *args: str, **kwargs: str) -> None:
|
||||
def run(self, *args: str, **kwargs: str) -> Optional[List[_Action]]:
|
||||
"""Override with the logic to execute on the cloudspace."""
|
||||
raise NotImplementedError
|
||||
|
||||
def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> None:
|
||||
def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> str:
|
||||
"""Run a job in the cloudspace associated with this plugin.
|
||||
|
||||
Args:
|
||||
name: The name of the job.
|
||||
app_entrypoint: The path of the file containing the app to run.
|
||||
env_vars: Additional env vars to set when running the app.
|
||||
|
||||
Returns:
|
||||
The relative URL of the created job.
|
||||
"""
|
||||
from lightning.app.runners.cloud import CloudRuntime
|
||||
|
||||
|
@ -74,12 +80,14 @@ class LightningPlugin:
|
|||
# Used to indicate Lightning has been dispatched
|
||||
os.environ["LIGHTNING_DISPATCHED"] = "1"
|
||||
|
||||
runtime.cloudspace_dispatch(
|
||||
url = runtime.cloudspace_dispatch(
|
||||
project_id=self.project_id,
|
||||
cloudspace_id=self.cloudspace_id,
|
||||
name=name,
|
||||
cluster_id=self.cluster_id,
|
||||
)
|
||||
# Return a relative URL so it can be used with the NavigateTo action.
|
||||
return url.replace(constants.get_lightning_cloud_url(), "")
|
||||
|
||||
def _setup(
|
||||
self,
|
||||
|
@ -101,7 +109,7 @@ class _Run(BaseModel):
|
|||
plugin_arguments: Dict[str, str]
|
||||
|
||||
|
||||
def _run_plugin(run: _Run) -> List:
|
||||
def _run_plugin(run: _Run) -> Dict[str, Any]:
|
||||
"""Create a run with the given name and entrypoint under the cloudspace with the given ID."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
download_path = os.path.join(tmpdir, "source.tar.gz")
|
||||
|
@ -115,6 +123,9 @@ def _run_plugin(run: _Run) -> List:
|
|||
|
||||
response = requests.get(source_code_url)
|
||||
|
||||
# TODO: Backoff retry a few times in case the URL is flaky
|
||||
response.raise_for_status()
|
||||
|
||||
with open(download_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
except Exception as e:
|
||||
|
@ -152,7 +163,8 @@ def _run_plugin(run: _Run) -> List:
|
|||
cloudspace_id=run.cloudspace_id,
|
||||
cluster_id=run.cluster_id,
|
||||
)
|
||||
plugin.run(**run.plugin_arguments)
|
||||
actions = plugin.run(**run.plugin_arguments) or []
|
||||
return {"actions": [action.to_spec().to_dict() for action in actions]}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(e)}."
|
||||
|
@ -160,9 +172,6 @@ def _run_plugin(run: _Run) -> List:
|
|||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
# TODO: Return actions from the plugin here
|
||||
return []
|
||||
|
||||
|
||||
def _start_plugin_server(host: str, port: int) -> None:
|
||||
"""Start the plugin server which can be used to dispatch apps or run plugins."""
|
|
@ -196,7 +196,7 @@ class CloudRuntime(Runtime):
|
|||
cloudspace_id: str,
|
||||
name: str,
|
||||
cluster_id: str,
|
||||
):
|
||||
) -> str:
|
||||
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
|
||||
such as the project and cluster IDs that are instead passed directly.
|
||||
|
||||
|
@ -210,12 +210,15 @@ class CloudRuntime(Runtime):
|
|||
ApiException: If there was an issue in the backend.
|
||||
RuntimeError: If there are validation errors.
|
||||
ValueError: If there are validation errors.
|
||||
|
||||
Returns:
|
||||
The URL of the created job.
|
||||
"""
|
||||
# Dispatch in four phases: resolution, validation, spec creation, API transactions
|
||||
# Resolution
|
||||
root = self._resolve_root()
|
||||
repo = self._resolve_repo(root)
|
||||
self._resolve_cloudspace(project_id, cloudspace_id)
|
||||
project = self._resolve_project(project_id=project_id)
|
||||
existing_instances = self._resolve_run_instances_by_name(project_id, name)
|
||||
name = self._resolve_run_name(name, existing_instances)
|
||||
queue_server_type = self._resolve_queue_server_type()
|
||||
|
@ -240,7 +243,7 @@ class CloudRuntime(Runtime):
|
|||
run = self._api_create_run(project_id, cloudspace_id, run_body)
|
||||
self._api_package_and_upload_repo(repo, run)
|
||||
|
||||
self._api_create_run_instance(
|
||||
run_instance = self._api_create_run_instance(
|
||||
cluster_id,
|
||||
project_id,
|
||||
name,
|
||||
|
@ -251,6 +254,8 @@ class CloudRuntime(Runtime):
|
|||
env_vars,
|
||||
)
|
||||
|
||||
return self._get_app_url(project, run_instance, "logs" if run.is_headless else "web-ui")
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
name: str = "",
|
||||
|
@ -451,16 +456,9 @@ class CloudRuntime(Runtime):
|
|||
|
||||
return LocalSourceCodeDir(path=root, ignore_functions=ignore_functions)
|
||||
|
||||
def _resolve_project(self) -> V1Membership:
|
||||
def _resolve_project(self, project_id: Optional[str] = None) -> V1Membership:
|
||||
"""Determine the project to run on, choosing a default if multiple projects are found."""
|
||||
return _get_project(self.backend.client)
|
||||
|
||||
def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpace:
|
||||
"""Get a cloudspace by project / cloudspace ID."""
|
||||
return self.backend.client.cloud_space_service_get_cloud_space(
|
||||
project_id=project_id,
|
||||
id=cloudspace_id,
|
||||
)
|
||||
return _get_project(self.backend.client, project_id=project_id)
|
||||
|
||||
def _resolve_existing_cloudspaces(self, project_id: str, cloudspace_name: str) -> List[V1CloudSpace]:
|
||||
"""Lists all the cloudspaces with a name matching the provided cloudspace name."""
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from lightning_cloud.openapi import V1Membership
|
||||
|
||||
|
@ -22,10 +23,11 @@ from lightning.app.utilities.enum import AppStage
|
|||
from lightning.app.utilities.network import LightningClient
|
||||
|
||||
|
||||
def _get_project(
|
||||
client: LightningClient, project_id: str = LIGHTNING_CLOUD_PROJECT_ID, verbose: bool = True
|
||||
) -> V1Membership:
|
||||
def _get_project(client: LightningClient, project_id: Optional[str] = None, verbose: bool = True) -> V1Membership:
|
||||
"""Get a project membership for the user from the backend."""
|
||||
if project_id is None:
|
||||
project_id = LIGHTNING_CLOUD_PROJECT_ID
|
||||
|
||||
projects = client.projects_service_list_memberships()
|
||||
if project_id is not None:
|
||||
for membership in projects.memberships:
|
||||
|
|
|
@ -25,7 +25,7 @@ from lightning.app.utilities.exceptions import MisconfigurationException
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from lightning.app import LightningApp, LightningFlow, LightningWork
|
||||
from lightning.app.core.plugin import LightningPlugin
|
||||
from lightning.app.plugin.plugin import LightningPlugin
|
||||
|
||||
from lightning.app.utilities.app_helpers import _mock_missing_imports, Logger
|
||||
|
||||
|
@ -85,7 +85,7 @@ def _load_objects_from_file(
|
|||
|
||||
|
||||
def _load_plugin_from_file(filepath: str) -> "LightningPlugin":
|
||||
from lightning.app.core.plugin import LightningPlugin
|
||||
from lightning.app.plugin.plugin import LightningPlugin
|
||||
|
||||
# TODO: Plugin should be run in the context of the created main module here
|
||||
plugins, _ = _load_objects_from_file(filepath, LightningPlugin, raise_exception=True, mock_imports=False)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import io
|
||||
import json
|
||||
import sys
|
||||
import tarfile
|
||||
from dataclasses import dataclass
|
||||
|
@ -9,11 +10,11 @@ import pytest
|
|||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from lightning.app.core.plugin import _Run, _start_plugin_server
|
||||
from lightning.app.plugin.plugin import _Run, _start_plugin_server
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@mock.patch("lightning.app.core.plugin.uvicorn")
|
||||
@mock.patch("lightning.app.plugin.plugin.uvicorn")
|
||||
def mock_plugin_server(mock_uvicorn) -> TestClient:
|
||||
"""This fixture returns a `TestClient` for the plugin server."""
|
||||
|
||||
|
@ -33,6 +34,9 @@ def mock_plugin_server(mock_uvicorn) -> TestClient:
|
|||
class _MockResponse:
|
||||
content: bytes
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
|
||||
def mock_requests_get(valid_url, return_value):
|
||||
"""Used to replace `requests.get` with a function that returns the given value for the given valid URL and
|
||||
|
@ -59,7 +63,7 @@ def as_tar_bytes(file_name, content):
|
|||
|
||||
|
||||
_plugin_with_internal_error = """
|
||||
from lightning.app.core.plugin import LightningPlugin
|
||||
from lightning.app.plugin.plugin import LightningPlugin
|
||||
|
||||
class TestPlugin(LightningPlugin):
|
||||
def run(self):
|
||||
|
@ -127,7 +131,7 @@ plugin = TestPlugin()
|
|||
),
|
||||
],
|
||||
)
|
||||
@mock.patch("lightning.app.core.plugin.requests")
|
||||
@mock.patch("lightning.app.plugin.plugin.requests")
|
||||
def test_run_errors(mock_requests, mock_plugin_server, body, message, tar_file_name, content):
|
||||
if tar_file_name is not None:
|
||||
content = as_tar_bytes(tar_file_name, content)
|
||||
|
@ -140,8 +144,8 @@ def test_run_errors(mock_requests, mock_plugin_server, body, message, tar_file_n
|
|||
assert message in response.text
|
||||
|
||||
|
||||
_plugin_with_job_run = """
|
||||
from lightning.app.core.plugin import LightningPlugin
|
||||
_plugin_with_job_run_no_actions = """
|
||||
from lightning.app.plugin.plugin import LightningPlugin
|
||||
|
||||
class TestPlugin(LightningPlugin):
|
||||
def run(self, name, entrypoint):
|
||||
|
@ -151,13 +155,46 @@ plugin = TestPlugin()
|
|||
"""
|
||||
|
||||
|
||||
_plugin_with_job_run_toast = """
|
||||
from lightning.app.plugin.actions import Toast
|
||||
from lightning.app.plugin.plugin import LightningPlugin
|
||||
|
||||
class TestPlugin(LightningPlugin):
|
||||
def run(self, name, entrypoint):
|
||||
self.run_job(name, entrypoint)
|
||||
return [Toast("info", "testing")]
|
||||
|
||||
plugin = TestPlugin()
|
||||
"""
|
||||
|
||||
_plugin_with_job_run_navigate = """
|
||||
from lightning.app.plugin.actions import NavigateTo
|
||||
from lightning.app.plugin.plugin import LightningPlugin
|
||||
|
||||
class TestPlugin(LightningPlugin):
|
||||
def run(self, name, entrypoint):
|
||||
self.run_job(name, entrypoint)
|
||||
return [NavigateTo("/testing")]
|
||||
|
||||
plugin = TestPlugin()
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="the plugin server is only intended to run on linux.")
|
||||
@pytest.mark.parametrize(
|
||||
"plugin_source, actions",
|
||||
[
|
||||
(_plugin_with_job_run_no_actions, []),
|
||||
(_plugin_with_job_run_toast, [{"content": "info:testing", "type": "TOAST"}]),
|
||||
(_plugin_with_job_run_navigate, [{"content": "/testing", "type": "NAVIGATE_TO"}]),
|
||||
],
|
||||
)
|
||||
@mock.patch("lightning.app.runners.cloud.CloudRuntime")
|
||||
@mock.patch("lightning.app.core.plugin.requests")
|
||||
def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server):
|
||||
@mock.patch("lightning.app.plugin.plugin.requests")
|
||||
def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server, plugin_source, actions):
|
||||
"""Tests that running a job from a plugin calls the correct `CloudRuntime` methods with the correct
|
||||
arguments."""
|
||||
content = as_tar_bytes("plugin.py", _plugin_with_job_run)
|
||||
content = as_tar_bytes("plugin.py", plugin_source)
|
||||
mock_requests.get.side_effect = mock_requests_get("http://test.tar.gz", content)
|
||||
|
||||
body = _Run(
|
||||
|
@ -175,6 +212,7 @@ def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server):
|
|||
response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True))
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert json.loads(response.text)["actions"] == actions
|
||||
|
||||
mock_cloud_runtime.load_app_from_file.assert_called_once()
|
||||
assert "test_entrypoint" in mock_cloud_runtime.load_app_from_file.call_args[0][0]
|
|
@ -629,7 +629,7 @@ class TestAppCreationClient:
|
|||
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
|
||||
monkeypatch.setattr(
|
||||
"lightning.app.runners.cloud._get_project",
|
||||
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
)
|
||||
cloud_runtime.dispatch()
|
||||
|
||||
|
@ -819,7 +819,7 @@ class TestAppCreationClient:
|
|||
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
|
||||
monkeypatch.setattr(
|
||||
"lightning.app.runners.cloud._get_project",
|
||||
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
)
|
||||
cloud_runtime.dispatch()
|
||||
|
||||
|
@ -954,7 +954,7 @@ class TestAppCreationClient:
|
|||
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
|
||||
monkeypatch.setattr(
|
||||
"lightning.app.runners.cloud._get_project",
|
||||
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
)
|
||||
cloud_runtime.run_app_comment_commands = True
|
||||
cloud_runtime.dispatch()
|
||||
|
@ -1094,7 +1094,7 @@ class TestAppCreationClient:
|
|||
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
|
||||
monkeypatch.setattr(
|
||||
"lightning.app.runners.cloud._get_project",
|
||||
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
)
|
||||
cloud_runtime.dispatch()
|
||||
|
||||
|
@ -1314,7 +1314,7 @@ class TestAppCreationClient:
|
|||
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
|
||||
monkeypatch.setattr(
|
||||
"lightning.app.runners.cloud._get_project",
|
||||
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
|
||||
)
|
||||
cloud_runtime.dispatch()
|
||||
|
||||
|
@ -1596,6 +1596,7 @@ class TestCloudspaceDispatch:
|
|||
mock_client = mock.MagicMock()
|
||||
mock_client.auth_service_get_user.return_value = V1GetUserResponse(
|
||||
username="tester",
|
||||
features=V1UserFeatures(),
|
||||
)
|
||||
mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
|
||||
memberships=[V1Membership(name="project", project_id="project_id")]
|
||||
|
|
Loading…
Reference in New Issue