[App] Add support for plugins to return actions (#16832)

This commit is contained in:
Ethan Harris 2023-02-22 17:12:04 +00:00 committed by GitHub
parent 62e3d5854f
commit beced48904
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 168 additions and 44 deletions

View File

@ -1,4 +1,4 @@
lightning-cloud>=0.5.26
lightning-cloud>=0.5.27
packaging
typing-extensions>=4.0.0, <=4.4.0
deepdiff>=5.7.0, <6.2.4

View File

@ -33,9 +33,9 @@ if "__version__" not in locals():
from lightning.app.core.app import LightningApp # noqa: E402
from lightning.app.core.flow import LightningFlow # noqa: E402
from lightning.app.core.plugin import LightningPlugin # noqa: E402
from lightning.app.core.work import LightningWork # noqa: E402
from lightning.app.perf import pdb # noqa: E402
from lightning.app.plugin.plugin import LightningPlugin # noqa: E402
from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402
from lightning.app.utilities.packaging.cloud_compute import CloudCompute # noqa: E402

View File

@ -1,6 +1,5 @@
from lightning.app.core.app import LightningApp
from lightning.app.core.flow import LightningFlow
from lightning.app.core.plugin import LightningPlugin
from lightning.app.core.work import LightningWork
__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin"]
__all__ = ["LightningApp", "LightningFlow", "LightningWork"]

View File

@ -0,0 +1,5 @@
import lightning.app.plugin.actions as actions
from lightning.app.plugin.actions import NavigateTo, Toast, ToastSeverity
from lightning.app.plugin.plugin import LightningPlugin
__all__ = ["LightningPlugin", "actions", "Toast", "ToastSeverity", "NavigateTo"]

View File

@ -0,0 +1,72 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
from typing import Union
from lightning_cloud.openapi.models import V1CloudSpaceAppAction, V1CloudSpaceAppActionType
class _Action:
"""Actions are returned by `LightningPlugin` objects to perform actions in the UI."""
def to_spec(self) -> V1CloudSpaceAppAction:
"""Convert this action to a ``V1CloudSpaceAppAction``"""
raise NotImplementedError
@dataclass
class NavigateTo(_Action):
"""The ``NavigateTo`` action can be used to navigate to a relative URL within the Lightning frontend.
Args:
url: The relative URL to navigate to. E.g. ``/<username>/<project>``.
"""
url: str
def to_spec(self) -> V1CloudSpaceAppAction:
return V1CloudSpaceAppAction(
type=V1CloudSpaceAppActionType.NAVIGATE_TO,
content=self.url,
)
class ToastSeverity(Enum):
ERROR = "error"
INFO = "info"
SUCCESS = "success"
WARNING = "warning"
def __str__(self) -> str:
return self.value
@dataclass
class Toast(_Action):
"""The ``Toast`` action can be used to display a toast message to the user.
Args:
severity: The severity level of the toast. One of: "error", "info", "success", "warning".
message: The message body.
"""
severity: Union[ToastSeverity, str]
message: str
def to_spec(self) -> V1CloudSpaceAppAction:
return V1CloudSpaceAppAction(
type=V1CloudSpaceAppActionType.TOAST,
content=f"{self.severity}:{self.message}",
)

View File

@ -15,7 +15,7 @@ import os
import tarfile
import tempfile
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import requests
@ -24,6 +24,8 @@ from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from lightning.app.core import constants
from lightning.app.plugin.actions import _Action
from lightning.app.utilities.app_helpers import Logger
from lightning.app.utilities.component import _set_flow_context
from lightning.app.utilities.enum import AppStage
@ -41,16 +43,20 @@ class LightningPlugin:
self.cloudspace_id = ""
self.cluster_id = ""
def run(self, *args: str, **kwargs: str) -> None:
def run(self, *args: str, **kwargs: str) -> Optional[List[_Action]]:
"""Override with the logic to execute on the cloudspace."""
raise NotImplementedError
def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> None:
def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> str:
"""Run a job in the cloudspace associated with this plugin.
Args:
name: The name of the job.
app_entrypoint: The path of the file containing the app to run.
env_vars: Additional env vars to set when running the app.
Returns:
The relative URL of the created job.
"""
from lightning.app.runners.cloud import CloudRuntime
@ -74,12 +80,14 @@ class LightningPlugin:
# Used to indicate Lightning has been dispatched
os.environ["LIGHTNING_DISPATCHED"] = "1"
runtime.cloudspace_dispatch(
url = runtime.cloudspace_dispatch(
project_id=self.project_id,
cloudspace_id=self.cloudspace_id,
name=name,
cluster_id=self.cluster_id,
)
# Return a relative URL so it can be used with the NavigateTo action.
return url.replace(constants.get_lightning_cloud_url(), "")
def _setup(
self,
@ -101,7 +109,7 @@ class _Run(BaseModel):
plugin_arguments: Dict[str, str]
def _run_plugin(run: _Run) -> List:
def _run_plugin(run: _Run) -> Dict[str, Any]:
"""Create a run with the given name and entrypoint under the cloudspace with the given ID."""
with tempfile.TemporaryDirectory() as tmpdir:
download_path = os.path.join(tmpdir, "source.tar.gz")
@ -115,6 +123,9 @@ def _run_plugin(run: _Run) -> List:
response = requests.get(source_code_url)
# TODO: Backoff retry a few times in case the URL is flaky
response.raise_for_status()
with open(download_path, "wb") as f:
f.write(response.content)
except Exception as e:
@ -152,7 +163,8 @@ def _run_plugin(run: _Run) -> List:
cloudspace_id=run.cloudspace_id,
cluster_id=run.cluster_id,
)
plugin.run(**run.plugin_arguments)
actions = plugin.run(**run.plugin_arguments) or []
return {"actions": [action.to_spec().to_dict() for action in actions]}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(e)}."
@ -160,9 +172,6 @@ def _run_plugin(run: _Run) -> List:
finally:
os.chdir(cwd)
# TODO: Return actions from the plugin here
return []
def _start_plugin_server(host: str, port: int) -> None:
"""Start the plugin server which can be used to dispatch apps or run plugins."""

View File

@ -196,7 +196,7 @@ class CloudRuntime(Runtime):
cloudspace_id: str,
name: str,
cluster_id: str,
):
) -> str:
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
such as the project and cluster IDs that are instead passed directly.
@ -210,12 +210,15 @@ class CloudRuntime(Runtime):
ApiException: If there was an issue in the backend.
RuntimeError: If there are validation errors.
ValueError: If there are validation errors.
Returns:
The URL of the created job.
"""
# Dispatch in four phases: resolution, validation, spec creation, API transactions
# Resolution
root = self._resolve_root()
repo = self._resolve_repo(root)
self._resolve_cloudspace(project_id, cloudspace_id)
project = self._resolve_project(project_id=project_id)
existing_instances = self._resolve_run_instances_by_name(project_id, name)
name = self._resolve_run_name(name, existing_instances)
queue_server_type = self._resolve_queue_server_type()
@ -240,7 +243,7 @@ class CloudRuntime(Runtime):
run = self._api_create_run(project_id, cloudspace_id, run_body)
self._api_package_and_upload_repo(repo, run)
self._api_create_run_instance(
run_instance = self._api_create_run_instance(
cluster_id,
project_id,
name,
@ -251,6 +254,8 @@ class CloudRuntime(Runtime):
env_vars,
)
return self._get_app_url(project, run_instance, "logs" if run.is_headless else "web-ui")
def dispatch(
self,
name: str = "",
@ -451,16 +456,9 @@ class CloudRuntime(Runtime):
return LocalSourceCodeDir(path=root, ignore_functions=ignore_functions)
def _resolve_project(self) -> V1Membership:
def _resolve_project(self, project_id: Optional[str] = None) -> V1Membership:
"""Determine the project to run on, choosing a default if multiple projects are found."""
return _get_project(self.backend.client)
def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpace:
"""Get a cloudspace by project / cloudspace ID."""
return self.backend.client.cloud_space_service_get_cloud_space(
project_id=project_id,
id=cloudspace_id,
)
return _get_project(self.backend.client, project_id=project_id)
def _resolve_existing_cloudspaces(self, project_id: str, cloudspace_name: str) -> List[V1CloudSpace]:
"""Lists all the cloudspaces with a name matching the provided cloudspace name."""

View File

@ -13,6 +13,7 @@
# limitations under the License.
import os
from typing import Optional
from lightning_cloud.openapi import V1Membership
@ -22,10 +23,11 @@ from lightning.app.utilities.enum import AppStage
from lightning.app.utilities.network import LightningClient
def _get_project(
client: LightningClient, project_id: str = LIGHTNING_CLOUD_PROJECT_ID, verbose: bool = True
) -> V1Membership:
def _get_project(client: LightningClient, project_id: Optional[str] = None, verbose: bool = True) -> V1Membership:
"""Get a project membership for the user from the backend."""
if project_id is None:
project_id = LIGHTNING_CLOUD_PROJECT_ID
projects = client.projects_service_list_memberships()
if project_id is not None:
for membership in projects.memberships:

View File

@ -25,7 +25,7 @@ from lightning.app.utilities.exceptions import MisconfigurationException
if TYPE_CHECKING:
from lightning.app import LightningApp, LightningFlow, LightningWork
from lightning.app.core.plugin import LightningPlugin
from lightning.app.plugin.plugin import LightningPlugin
from lightning.app.utilities.app_helpers import _mock_missing_imports, Logger
@ -85,7 +85,7 @@ def _load_objects_from_file(
def _load_plugin_from_file(filepath: str) -> "LightningPlugin":
from lightning.app.core.plugin import LightningPlugin
from lightning.app.plugin.plugin import LightningPlugin
# TODO: Plugin should be run in the context of the created main module here
plugins, _ = _load_objects_from_file(filepath, LightningPlugin, raise_exception=True, mock_imports=False)

View File

View File

@ -1,4 +1,5 @@
import io
import json
import sys
import tarfile
from dataclasses import dataclass
@ -9,11 +10,11 @@ import pytest
from fastapi import status
from fastapi.testclient import TestClient
from lightning.app.core.plugin import _Run, _start_plugin_server
from lightning.app.plugin.plugin import _Run, _start_plugin_server
@pytest.fixture()
@mock.patch("lightning.app.core.plugin.uvicorn")
@mock.patch("lightning.app.plugin.plugin.uvicorn")
def mock_plugin_server(mock_uvicorn) -> TestClient:
"""This fixture returns a `TestClient` for the plugin server."""
@ -33,6 +34,9 @@ def mock_plugin_server(mock_uvicorn) -> TestClient:
class _MockResponse:
content: bytes
def raise_for_status(self):
pass
def mock_requests_get(valid_url, return_value):
"""Used to replace `requests.get` with a function that returns the given value for the given valid URL and
@ -59,7 +63,7 @@ def as_tar_bytes(file_name, content):
_plugin_with_internal_error = """
from lightning.app.core.plugin import LightningPlugin
from lightning.app.plugin.plugin import LightningPlugin
class TestPlugin(LightningPlugin):
def run(self):
@ -127,7 +131,7 @@ plugin = TestPlugin()
),
],
)
@mock.patch("lightning.app.core.plugin.requests")
@mock.patch("lightning.app.plugin.plugin.requests")
def test_run_errors(mock_requests, mock_plugin_server, body, message, tar_file_name, content):
if tar_file_name is not None:
content = as_tar_bytes(tar_file_name, content)
@ -140,8 +144,8 @@ def test_run_errors(mock_requests, mock_plugin_server, body, message, tar_file_n
assert message in response.text
_plugin_with_job_run = """
from lightning.app.core.plugin import LightningPlugin
_plugin_with_job_run_no_actions = """
from lightning.app.plugin.plugin import LightningPlugin
class TestPlugin(LightningPlugin):
def run(self, name, entrypoint):
@ -151,13 +155,46 @@ plugin = TestPlugin()
"""
_plugin_with_job_run_toast = """
from lightning.app.plugin.actions import Toast
from lightning.app.plugin.plugin import LightningPlugin
class TestPlugin(LightningPlugin):
def run(self, name, entrypoint):
self.run_job(name, entrypoint)
return [Toast("info", "testing")]
plugin = TestPlugin()
"""
_plugin_with_job_run_navigate = """
from lightning.app.plugin.actions import NavigateTo
from lightning.app.plugin.plugin import LightningPlugin
class TestPlugin(LightningPlugin):
def run(self, name, entrypoint):
self.run_job(name, entrypoint)
return [NavigateTo("/testing")]
plugin = TestPlugin()
"""
@pytest.mark.skipif(sys.platform == "win32", reason="the plugin server is only intended to run on linux.")
@pytest.mark.parametrize(
"plugin_source, actions",
[
(_plugin_with_job_run_no_actions, []),
(_plugin_with_job_run_toast, [{"content": "info:testing", "type": "TOAST"}]),
(_plugin_with_job_run_navigate, [{"content": "/testing", "type": "NAVIGATE_TO"}]),
],
)
@mock.patch("lightning.app.runners.cloud.CloudRuntime")
@mock.patch("lightning.app.core.plugin.requests")
def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server):
@mock.patch("lightning.app.plugin.plugin.requests")
def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server, plugin_source, actions):
"""Tests that running a job from a plugin calls the correct `CloudRuntime` methods with the correct
arguments."""
content = as_tar_bytes("plugin.py", _plugin_with_job_run)
content = as_tar_bytes("plugin.py", plugin_source)
mock_requests.get.side_effect = mock_requests_get("http://test.tar.gz", content)
body = _Run(
@ -175,6 +212,7 @@ def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server):
response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True))
assert response.status_code == status.HTTP_200_OK
assert json.loads(response.text)["actions"] == actions
mock_cloud_runtime.load_app_from_file.assert_called_once()
assert "test_entrypoint" in mock_cloud_runtime.load_app_from_file.call_args[0][0]

View File

@ -629,7 +629,7 @@ class TestAppCreationClient:
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
monkeypatch.setattr(
"lightning.app.runners.cloud._get_project",
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
)
cloud_runtime.dispatch()
@ -819,7 +819,7 @@ class TestAppCreationClient:
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
monkeypatch.setattr(
"lightning.app.runners.cloud._get_project",
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
)
cloud_runtime.dispatch()
@ -954,7 +954,7 @@ class TestAppCreationClient:
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
monkeypatch.setattr(
"lightning.app.runners.cloud._get_project",
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
)
cloud_runtime.run_app_comment_commands = True
cloud_runtime.dispatch()
@ -1094,7 +1094,7 @@ class TestAppCreationClient:
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
monkeypatch.setattr(
"lightning.app.runners.cloud._get_project",
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
)
cloud_runtime.dispatch()
@ -1314,7 +1314,7 @@ class TestAppCreationClient:
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
monkeypatch.setattr(
"lightning.app.runners.cloud._get_project",
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
)
cloud_runtime.dispatch()
@ -1596,6 +1596,7 @@ class TestCloudspaceDispatch:
mock_client = mock.MagicMock()
mock_client.auth_service_get_user.return_value = V1GetUserResponse(
username="tester",
features=V1UserFeatures(),
)
mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
memberships=[V1Membership(name="project", project_id="project_id")]