From beced489040f76e7eee2f4a82d29823834b77327 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 22 Feb 2023 17:12:04 +0000 Subject: [PATCH] [App] Add support for plugins to return actions (#16832) --- requirements/app/base.txt | 2 +- src/lightning/app/__init__.py | 2 +- src/lightning/app/core/__init__.py | 3 +- src/lightning/app/plugin/__init__.py | 5 ++ src/lightning/app/plugin/actions.py | 72 +++++++++++++++++++ src/lightning/app/{core => plugin}/plugin.py | 27 ++++--- src/lightning/app/runners/cloud.py | 22 +++--- src/lightning/app/utilities/cloud.py | 8 ++- src/lightning/app/utilities/load_app.py | 4 +- tests/tests_app/plugin/__init__.py | 0 .../tests_app/{core => plugin}/test_plugin.py | 56 ++++++++++++--- tests/tests_app/runners/test_cloud.py | 11 +-- 12 files changed, 168 insertions(+), 44 deletions(-) create mode 100644 src/lightning/app/plugin/__init__.py create mode 100644 src/lightning/app/plugin/actions.py rename src/lightning/app/{core => plugin}/plugin.py (86%) create mode 100644 tests/tests_app/plugin/__init__.py rename tests/tests_app/{core => plugin}/test_plugin.py (78%) diff --git a/requirements/app/base.txt b/requirements/app/base.txt index c656a00b51..cd25c5c113 100644 --- a/requirements/app/base.txt +++ b/requirements/app/base.txt @@ -1,4 +1,4 @@ -lightning-cloud>=0.5.26 +lightning-cloud>=0.5.27 packaging typing-extensions>=4.0.0, <=4.4.0 deepdiff>=5.7.0, <6.2.4 diff --git a/src/lightning/app/__init__.py b/src/lightning/app/__init__.py index 6377d42232..d791ec20d5 100644 --- a/src/lightning/app/__init__.py +++ b/src/lightning/app/__init__.py @@ -33,9 +33,9 @@ if "__version__" not in locals(): from lightning.app.core.app import LightningApp # noqa: E402 from lightning.app.core.flow import LightningFlow # noqa: E402 -from lightning.app.core.plugin import LightningPlugin # noqa: E402 from lightning.app.core.work import LightningWork # noqa: E402 from lightning.app.perf import pdb # noqa: E402 +from lightning.app.plugin.plugin import LightningPlugin # noqa: E402 from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402 from lightning.app.utilities.packaging.cloud_compute import CloudCompute # noqa: E402 diff --git a/src/lightning/app/core/__init__.py b/src/lightning/app/core/__init__.py index a789cdfaf6..cdf8b6aee1 100644 --- a/src/lightning/app/core/__init__.py +++ b/src/lightning/app/core/__init__.py @@ -1,6 +1,5 @@ from lightning.app.core.app import LightningApp from lightning.app.core.flow import LightningFlow -from lightning.app.core.plugin import LightningPlugin from lightning.app.core.work import LightningWork -__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin"] +__all__ = ["LightningApp", "LightningFlow", "LightningWork"] diff --git a/src/lightning/app/plugin/__init__.py b/src/lightning/app/plugin/__init__.py new file mode 100644 index 0000000000..0899490827 --- /dev/null +++ b/src/lightning/app/plugin/__init__.py @@ -0,0 +1,5 @@ +import lightning.app.plugin.actions as actions +from lightning.app.plugin.actions import NavigateTo, Toast, ToastSeverity +from lightning.app.plugin.plugin import LightningPlugin + +__all__ = ["LightningPlugin", "actions", "Toast", "ToastSeverity", "NavigateTo"] diff --git a/src/lightning/app/plugin/actions.py b/src/lightning/app/plugin/actions.py new file mode 100644 index 0000000000..fd3c3cd0e3 --- /dev/null +++ b/src/lightning/app/plugin/actions.py @@ -0,0 +1,72 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from enum import Enum +from typing import Union + +from lightning_cloud.openapi.models import V1CloudSpaceAppAction, V1CloudSpaceAppActionType + + +class _Action: + """Actions are returned by `LightningPlugin` objects to perform actions in the UI.""" + + def to_spec(self) -> V1CloudSpaceAppAction: + """Convert this action to a ``V1CloudSpaceAppAction``""" + raise NotImplementedError + + +@dataclass +class NavigateTo(_Action): + """The ``NavigateTo`` action can be used to navigate to a relative URL within the Lightning frontend. + + Args: + url: The relative URL to navigate to. E.g. ``//``. + """ + + url: str + + def to_spec(self) -> V1CloudSpaceAppAction: + return V1CloudSpaceAppAction( + type=V1CloudSpaceAppActionType.NAVIGATE_TO, + content=self.url, + ) + + +class ToastSeverity(Enum): + ERROR = "error" + INFO = "info" + SUCCESS = "success" + WARNING = "warning" + + def __str__(self) -> str: + return self.value + + +@dataclass +class Toast(_Action): + """The ``Toast`` action can be used to display a toast message to the user. + + Args: + severity: The severity level of the toast. One of: "error", "info", "success", "warning". + message: The message body. + """ + + severity: Union[ToastSeverity, str] + message: str + + def to_spec(self) -> V1CloudSpaceAppAction: + return V1CloudSpaceAppAction( + type=V1CloudSpaceAppActionType.TOAST, + content=f"{self.severity}:{self.message}", + ) diff --git a/src/lightning/app/core/plugin.py b/src/lightning/app/plugin/plugin.py similarity index 86% rename from src/lightning/app/core/plugin.py rename to src/lightning/app/plugin/plugin.py index 65781ec234..415f1e9132 100644 --- a/src/lightning/app/core/plugin.py +++ b/src/lightning/app/plugin/plugin.py @@ -15,7 +15,7 @@ import os import tarfile import tempfile from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from urllib.parse import urlparse import requests @@ -24,6 +24,8 @@ from fastapi import FastAPI, HTTPException, status from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel +from lightning.app.core import constants +from lightning.app.plugin.actions import _Action from lightning.app.utilities.app_helpers import Logger from lightning.app.utilities.component import _set_flow_context from lightning.app.utilities.enum import AppStage @@ -41,16 +43,20 @@ class LightningPlugin: self.cloudspace_id = "" self.cluster_id = "" - def run(self, *args: str, **kwargs: str) -> None: + def run(self, *args: str, **kwargs: str) -> Optional[List[_Action]]: """Override with the logic to execute on the cloudspace.""" + raise NotImplementedError - def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> None: + def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> str: """Run a job in the cloudspace associated with this plugin. Args: name: The name of the job. app_entrypoint: The path of the file containing the app to run. env_vars: Additional env vars to set when running the app. + + Returns: + The relative URL of the created job. """ from lightning.app.runners.cloud import CloudRuntime @@ -74,12 +80,14 @@ class LightningPlugin: # Used to indicate Lightning has been dispatched os.environ["LIGHTNING_DISPATCHED"] = "1" - runtime.cloudspace_dispatch( + url = runtime.cloudspace_dispatch( project_id=self.project_id, cloudspace_id=self.cloudspace_id, name=name, cluster_id=self.cluster_id, ) + # Return a relative URL so it can be used with the NavigateTo action. + return url.replace(constants.get_lightning_cloud_url(), "") def _setup( self, @@ -101,7 +109,7 @@ class _Run(BaseModel): plugin_arguments: Dict[str, str] -def _run_plugin(run: _Run) -> List: +def _run_plugin(run: _Run) -> Dict[str, Any]: """Create a run with the given name and entrypoint under the cloudspace with the given ID.""" with tempfile.TemporaryDirectory() as tmpdir: download_path = os.path.join(tmpdir, "source.tar.gz") @@ -115,6 +123,9 @@ def _run_plugin(run: _Run) -> List: response = requests.get(source_code_url) + # TODO: Backoff retry a few times in case the URL is flaky + response.raise_for_status() + with open(download_path, "wb") as f: f.write(response.content) except Exception as e: @@ -152,7 +163,8 @@ def _run_plugin(run: _Run) -> List: cloudspace_id=run.cloudspace_id, cluster_id=run.cluster_id, ) - plugin.run(**run.plugin_arguments) + actions = plugin.run(**run.plugin_arguments) or [] + return {"actions": [action.to_spec().to_dict() for action in actions]} except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(e)}." @@ -160,9 +172,6 @@ def _run_plugin(run: _Run) -> List: finally: os.chdir(cwd) - # TODO: Return actions from the plugin here - return [] - def _start_plugin_server(host: str, port: int) -> None: """Start the plugin server which can be used to dispatch apps or run plugins.""" diff --git a/src/lightning/app/runners/cloud.py b/src/lightning/app/runners/cloud.py index 9a21d63742..610985e6b8 100644 --- a/src/lightning/app/runners/cloud.py +++ b/src/lightning/app/runners/cloud.py @@ -196,7 +196,7 @@ class CloudRuntime(Runtime): cloudspace_id: str, name: str, cluster_id: str, - ): + ) -> str: """Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties such as the project and cluster IDs that are instead passed directly. @@ -210,12 +210,15 @@ class CloudRuntime(Runtime): ApiException: If there was an issue in the backend. RuntimeError: If there are validation errors. ValueError: If there are validation errors. + + Returns: + The URL of the created job. """ # Dispatch in four phases: resolution, validation, spec creation, API transactions # Resolution root = self._resolve_root() repo = self._resolve_repo(root) - self._resolve_cloudspace(project_id, cloudspace_id) + project = self._resolve_project(project_id=project_id) existing_instances = self._resolve_run_instances_by_name(project_id, name) name = self._resolve_run_name(name, existing_instances) queue_server_type = self._resolve_queue_server_type() @@ -240,7 +243,7 @@ class CloudRuntime(Runtime): run = self._api_create_run(project_id, cloudspace_id, run_body) self._api_package_and_upload_repo(repo, run) - self._api_create_run_instance( + run_instance = self._api_create_run_instance( cluster_id, project_id, name, @@ -251,6 +254,8 @@ class CloudRuntime(Runtime): env_vars, ) + return self._get_app_url(project, run_instance, "logs" if run.is_headless else "web-ui") + def dispatch( self, name: str = "", @@ -451,16 +456,9 @@ class CloudRuntime(Runtime): return LocalSourceCodeDir(path=root, ignore_functions=ignore_functions) - def _resolve_project(self) -> V1Membership: + def _resolve_project(self, project_id: Optional[str] = None) -> V1Membership: """Determine the project to run on, choosing a default if multiple projects are found.""" - return _get_project(self.backend.client) - - def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpace: - """Get a cloudspace by project / cloudspace ID.""" - return self.backend.client.cloud_space_service_get_cloud_space( - project_id=project_id, - id=cloudspace_id, - ) + return _get_project(self.backend.client, project_id=project_id) def _resolve_existing_cloudspaces(self, project_id: str, cloudspace_name: str) -> List[V1CloudSpace]: """Lists all the cloudspaces with a name matching the provided cloudspace name.""" diff --git a/src/lightning/app/utilities/cloud.py b/src/lightning/app/utilities/cloud.py index c1b66f3d47..f5145fe10c 100644 --- a/src/lightning/app/utilities/cloud.py +++ b/src/lightning/app/utilities/cloud.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from typing import Optional from lightning_cloud.openapi import V1Membership @@ -22,10 +23,11 @@ from lightning.app.utilities.enum import AppStage from lightning.app.utilities.network import LightningClient -def _get_project( - client: LightningClient, project_id: str = LIGHTNING_CLOUD_PROJECT_ID, verbose: bool = True -) -> V1Membership: +def _get_project(client: LightningClient, project_id: Optional[str] = None, verbose: bool = True) -> V1Membership: """Get a project membership for the user from the backend.""" + if project_id is None: + project_id = LIGHTNING_CLOUD_PROJECT_ID + projects = client.projects_service_list_memberships() if project_id is not None: for membership in projects.memberships: diff --git a/src/lightning/app/utilities/load_app.py b/src/lightning/app/utilities/load_app.py index 7f50c344db..a98c9d934d 100644 --- a/src/lightning/app/utilities/load_app.py +++ b/src/lightning/app/utilities/load_app.py @@ -25,7 +25,7 @@ from lightning.app.utilities.exceptions import MisconfigurationException if TYPE_CHECKING: from lightning.app import LightningApp, LightningFlow, LightningWork - from lightning.app.core.plugin import LightningPlugin + from lightning.app.plugin.plugin import LightningPlugin from lightning.app.utilities.app_helpers import _mock_missing_imports, Logger @@ -85,7 +85,7 @@ def _load_objects_from_file( def _load_plugin_from_file(filepath: str) -> "LightningPlugin": - from lightning.app.core.plugin import LightningPlugin + from lightning.app.plugin.plugin import LightningPlugin # TODO: Plugin should be run in the context of the created main module here plugins, _ = _load_objects_from_file(filepath, LightningPlugin, raise_exception=True, mock_imports=False) diff --git a/tests/tests_app/plugin/__init__.py b/tests/tests_app/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tests_app/core/test_plugin.py b/tests/tests_app/plugin/test_plugin.py similarity index 78% rename from tests/tests_app/core/test_plugin.py rename to tests/tests_app/plugin/test_plugin.py index fddcec0500..87896fd0c0 100644 --- a/tests/tests_app/core/test_plugin.py +++ b/tests/tests_app/plugin/test_plugin.py @@ -1,4 +1,5 @@ import io +import json import sys import tarfile from dataclasses import dataclass @@ -9,11 +10,11 @@ import pytest from fastapi import status from fastapi.testclient import TestClient -from lightning.app.core.plugin import _Run, _start_plugin_server +from lightning.app.plugin.plugin import _Run, _start_plugin_server @pytest.fixture() -@mock.patch("lightning.app.core.plugin.uvicorn") +@mock.patch("lightning.app.plugin.plugin.uvicorn") def mock_plugin_server(mock_uvicorn) -> TestClient: """This fixture returns a `TestClient` for the plugin server.""" @@ -33,6 +34,9 @@ def mock_plugin_server(mock_uvicorn) -> TestClient: class _MockResponse: content: bytes + def raise_for_status(self): + pass + def mock_requests_get(valid_url, return_value): """Used to replace `requests.get` with a function that returns the given value for the given valid URL and @@ -59,7 +63,7 @@ def as_tar_bytes(file_name, content): _plugin_with_internal_error = """ -from lightning.app.core.plugin import LightningPlugin +from lightning.app.plugin.plugin import LightningPlugin class TestPlugin(LightningPlugin): def run(self): @@ -127,7 +131,7 @@ plugin = TestPlugin() ), ], ) -@mock.patch("lightning.app.core.plugin.requests") +@mock.patch("lightning.app.plugin.plugin.requests") def test_run_errors(mock_requests, mock_plugin_server, body, message, tar_file_name, content): if tar_file_name is not None: content = as_tar_bytes(tar_file_name, content) @@ -140,8 +144,8 @@ def test_run_errors(mock_requests, mock_plugin_server, body, message, tar_file_n assert message in response.text -_plugin_with_job_run = """ -from lightning.app.core.plugin import LightningPlugin +_plugin_with_job_run_no_actions = """ +from lightning.app.plugin.plugin import LightningPlugin class TestPlugin(LightningPlugin): def run(self, name, entrypoint): @@ -151,13 +155,46 @@ plugin = TestPlugin() """ +_plugin_with_job_run_toast = """ +from lightning.app.plugin.actions import Toast +from lightning.app.plugin.plugin import LightningPlugin + +class TestPlugin(LightningPlugin): + def run(self, name, entrypoint): + self.run_job(name, entrypoint) + return [Toast("info", "testing")] + +plugin = TestPlugin() +""" + +_plugin_with_job_run_navigate = """ +from lightning.app.plugin.actions import NavigateTo +from lightning.app.plugin.plugin import LightningPlugin + +class TestPlugin(LightningPlugin): + def run(self, name, entrypoint): + self.run_job(name, entrypoint) + return [NavigateTo("/testing")] + +plugin = TestPlugin() +""" + + @pytest.mark.skipif(sys.platform == "win32", reason="the plugin server is only intended to run on linux.") +@pytest.mark.parametrize( + "plugin_source, actions", + [ + (_plugin_with_job_run_no_actions, []), + (_plugin_with_job_run_toast, [{"content": "info:testing", "type": "TOAST"}]), + (_plugin_with_job_run_navigate, [{"content": "/testing", "type": "NAVIGATE_TO"}]), + ], +) @mock.patch("lightning.app.runners.cloud.CloudRuntime") -@mock.patch("lightning.app.core.plugin.requests") -def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server): +@mock.patch("lightning.app.plugin.plugin.requests") +def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server, plugin_source, actions): """Tests that running a job from a plugin calls the correct `CloudRuntime` methods with the correct arguments.""" - content = as_tar_bytes("plugin.py", _plugin_with_job_run) + content = as_tar_bytes("plugin.py", plugin_source) mock_requests.get.side_effect = mock_requests_get("http://test.tar.gz", content) body = _Run( @@ -175,6 +212,7 @@ def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server): response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True)) assert response.status_code == status.HTTP_200_OK + assert json.loads(response.text)["actions"] == actions mock_cloud_runtime.load_app_from_file.assert_called_once() assert "test_entrypoint" in mock_cloud_runtime.load_app_from_file.call_args[0][0] diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py index 1b10b0e360..e5a7132f01 100644 --- a/tests/tests_app/runners/test_cloud.py +++ b/tests/tests_app/runners/test_cloud.py @@ -629,7 +629,7 @@ class TestAppCreationClient: cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py")) monkeypatch.setattr( "lightning.app.runners.cloud._get_project", - lambda x: V1Membership(name="test-project", project_id="test-project-id"), + lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"), ) cloud_runtime.dispatch() @@ -819,7 +819,7 @@ class TestAppCreationClient: cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py")) monkeypatch.setattr( "lightning.app.runners.cloud._get_project", - lambda x: V1Membership(name="test-project", project_id="test-project-id"), + lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"), ) cloud_runtime.dispatch() @@ -954,7 +954,7 @@ class TestAppCreationClient: cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py")) monkeypatch.setattr( "lightning.app.runners.cloud._get_project", - lambda x: V1Membership(name="test-project", project_id="test-project-id"), + lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"), ) cloud_runtime.run_app_comment_commands = True cloud_runtime.dispatch() @@ -1094,7 +1094,7 @@ class TestAppCreationClient: cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py")) monkeypatch.setattr( "lightning.app.runners.cloud._get_project", - lambda x: V1Membership(name="test-project", project_id="test-project-id"), + lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"), ) cloud_runtime.dispatch() @@ -1314,7 +1314,7 @@ class TestAppCreationClient: cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py")) monkeypatch.setattr( "lightning.app.runners.cloud._get_project", - lambda x: V1Membership(name="test-project", project_id="test-project-id"), + lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"), ) cloud_runtime.dispatch() @@ -1596,6 +1596,7 @@ class TestCloudspaceDispatch: mock_client = mock.MagicMock() mock_client.auth_service_get_user.return_value = V1GetUserResponse( username="tester", + features=V1UserFeatures(), ) mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse( memberships=[V1Membership(name="project", project_id="project_id")]