[App] Add support for private data (#16738)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
thomas chaton 2023-02-14 12:34:44 +00:00 committed by GitHub
parent 32e71377a8
commit 7e8400d277
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 26 deletions

View File

@ -1,4 +1,4 @@
lightning-cloud>=0.5.24
lightning-cloud>=0.5.26
packaging
typing-extensions>=4.0.0, <=4.4.0
deepdiff>=5.7.0, <6.2.4

View File

@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
-
- Added `lightning connect data` to register data connection to private s3 buckets ([#16738](https://github.com/Lightning-AI/lightning/pull/16738))
### Changed

View File

@ -19,7 +19,6 @@ from typing import Generator, List, Optional
import click
import rich
from fastapi import HTTPException
from lightning_cloud.openapi import Externalv1LightningappInstance
from rich.console import Console
from rich.live import Live
@ -65,7 +64,7 @@ def ls(path: Optional[str] = None, print: bool = True, use_live: bool = True) ->
lines = f.readlines()
root = lines[0].replace("\n", "")
client = LightningClient()
client = LightningClient(retry=False)
projects = client.projects_service_list_memberships()
if root == "/":
@ -256,7 +255,7 @@ def _collect_artifacts(
page_token=response.next_page_token,
tokens=tokens,
)
except HTTPException:
except Exception:
# Note: This is triggered when the request is wrong.
# This is currently happening due to looping through the user clusters.
pass

View File

@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import sys
import click
import lightning_cloud
import rich
from rich.live import Live
from rich.spinner import Spinner
@ -29,20 +31,32 @@ logger = Logger(__name__)
@click.argument("name", required=True)
@click.argument("region", required=True)
@click.argument("source", required=True)
@click.argument("destination", required=False)
@click.argument("project_name", required=False)
@click.option("--region", help="The AWS region of your bucket. Example: `us-west-1`.", required=True)
@click.option(
"--source", help="The URL path to your AWS S3 folder. Example: `s3://pl-flash-data/images/`.", required=True
)
@click.option(
"--secret_arn_name",
help="The name of role stored as a secret on Lightning AI to access your data. "
"Learn more with https://gist.github.com/tchaton/12ad4b788012e83c0eb35e6223ae09fc. "
"Example: `my_role`.",
required=False,
)
@click.option(
"--destination", help="Where your data should appear in the cloud. Currently not supported.", required=False
)
@click.option("--project_name", help="The project name on which to create the data connection.", required=False)
def connect_data(
name: str,
region: str,
source: str,
secret_arn_name: str = "",
destination: str = "",
project_name: str = "",
) -> None:
"""Create a new data connection."""
from lightning_cloud.openapi import ProjectIdDataConnectionsBody
from lightning_cloud.openapi import Create, V1AwsDataConnection
if sys.platform == "win32":
_error_and_exit("Data connection isn't supported on windows. Open an issue on Github.")
@ -51,7 +65,7 @@ def connect_data(
live.stop()
client = LightningClient()
client = LightningClient(retry=False)
projects = client.projects_service_list_memberships()
project_id = None
@ -71,12 +85,15 @@ def connect_data(
)
try:
_ = client.data_connection_service_create_data_connection(
body=ProjectIdDataConnectionsBody(
client.data_connection_service_create_data_connection(
body=Create(
name=name,
region=region,
source=source,
destination=destination,
aws=V1AwsDataConnection(
region=region,
source=source,
destination=destination,
secret_arn_name=secret_arn_name,
),
),
project_id=project_id,
)
@ -86,8 +103,8 @@ def connect_data(
# project_id=project_id,
# id=response.id,
# )
# print(response)
except Exception:
_error_and_exit("The data connection creation failed.")
except lightning_cloud.openapi.rest.ApiException as e:
message = ast.literal_eval(e.body.decode("utf-8"))["message"]
_error_and_exit(f"The data connection creation failed. Message: {message}")
rich.print(f"[green]Succeeded[/green]: You have created a new data connection {name}.")

View File

@ -1,3 +1,4 @@
import sys
from unittest.mock import MagicMock
import pytest
@ -5,7 +6,7 @@ import pytest
from lightning.app.cli.connect import data
@pytest.mark.skipif(True, reason="In progress")
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
def test_connect_data_no_project(monkeypatch):
from lightning_cloud.openapi import V1ListMembershipsResponse, V1Membership
@ -26,10 +27,10 @@ def test_connect_data_no_project(monkeypatch):
_get_project.assert_called()
@pytest.mark.skipif(True, reason="In progress")
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
def test_connect_data(monkeypatch):
from lightning_cloud.openapi import ProjectIdDataConnectionsBody, V1ListMembershipsResponse, V1Membership
from lightning_cloud.openapi import Create, V1AwsDataConnection, V1ListMembershipsResponse, V1Membership
client = MagicMock()
client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
@ -53,10 +54,8 @@ def test_connect_data(monkeypatch):
client.data_connection_service_create_data_connection.assert_called_with(
project_id="project-id-0",
body=ProjectIdDataConnectionsBody(
destination="",
region="us-east-1",
body=Create(
name="imagenet",
source="s3://imagenet",
aws=V1AwsDataConnection(destination="", region="us-east-1", source="s3://imagenet", secret_arn_name=""),
),
)