[App] Add support for private data (#16738)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
parent
32e71377a8
commit
7e8400d277
|
@ -1,4 +1,4 @@
|
|||
lightning-cloud>=0.5.24
|
||||
lightning-cloud>=0.5.26
|
||||
packaging
|
||||
typing-extensions>=4.0.0, <=4.4.0
|
||||
deepdiff>=5.7.0, <6.2.4
|
||||
|
|
|
@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
-
|
||||
- Added `lightning connect data` to register data connection to private s3 buckets ([#16738](https://github.com/Lightning-AI/lightning/pull/16738))
|
||||
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -19,7 +19,6 @@ from typing import Generator, List, Optional
|
|||
|
||||
import click
|
||||
import rich
|
||||
from fastapi import HTTPException
|
||||
from lightning_cloud.openapi import Externalv1LightningappInstance
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
|
@ -65,7 +64,7 @@ def ls(path: Optional[str] = None, print: bool = True, use_live: bool = True) ->
|
|||
lines = f.readlines()
|
||||
root = lines[0].replace("\n", "")
|
||||
|
||||
client = LightningClient()
|
||||
client = LightningClient(retry=False)
|
||||
projects = client.projects_service_list_memberships()
|
||||
|
||||
if root == "/":
|
||||
|
@ -256,7 +255,7 @@ def _collect_artifacts(
|
|||
page_token=response.next_page_token,
|
||||
tokens=tokens,
|
||||
)
|
||||
except HTTPException:
|
||||
except Exception:
|
||||
# Note: This is triggered when the request is wrong.
|
||||
# This is currently happening due to looping through the user clusters.
|
||||
pass
|
||||
|
|
|
@ -12,9 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ast
|
||||
import sys
|
||||
|
||||
import click
|
||||
import lightning_cloud
|
||||
import rich
|
||||
from rich.live import Live
|
||||
from rich.spinner import Spinner
|
||||
|
@ -29,20 +31,32 @@ logger = Logger(__name__)
|
|||
|
||||
|
||||
@click.argument("name", required=True)
|
||||
@click.argument("region", required=True)
|
||||
@click.argument("source", required=True)
|
||||
@click.argument("destination", required=False)
|
||||
@click.argument("project_name", required=False)
|
||||
@click.option("--region", help="The AWS region of your bucket. Example: `us-west-1`.", required=True)
|
||||
@click.option(
|
||||
"--source", help="The URL path to your AWS S3 folder. Example: `s3://pl-flash-data/images/`.", required=True
|
||||
)
|
||||
@click.option(
|
||||
"--secret_arn_name",
|
||||
help="The name of role stored as a secret on Lightning AI to access your data. "
|
||||
"Learn more with https://gist.github.com/tchaton/12ad4b788012e83c0eb35e6223ae09fc. "
|
||||
"Example: `my_role`.",
|
||||
required=False,
|
||||
)
|
||||
@click.option(
|
||||
"--destination", help="Where your data should appear in the cloud. Currently not supported.", required=False
|
||||
)
|
||||
@click.option("--project_name", help="The project name on which to create the data connection.", required=False)
|
||||
def connect_data(
|
||||
name: str,
|
||||
region: str,
|
||||
source: str,
|
||||
secret_arn_name: str = "",
|
||||
destination: str = "",
|
||||
project_name: str = "",
|
||||
) -> None:
|
||||
"""Create a new data connection."""
|
||||
|
||||
from lightning_cloud.openapi import ProjectIdDataConnectionsBody
|
||||
from lightning_cloud.openapi import Create, V1AwsDataConnection
|
||||
|
||||
if sys.platform == "win32":
|
||||
_error_and_exit("Data connection isn't supported on windows. Open an issue on Github.")
|
||||
|
@ -51,7 +65,7 @@ def connect_data(
|
|||
|
||||
live.stop()
|
||||
|
||||
client = LightningClient()
|
||||
client = LightningClient(retry=False)
|
||||
projects = client.projects_service_list_memberships()
|
||||
|
||||
project_id = None
|
||||
|
@ -71,12 +85,15 @@ def connect_data(
|
|||
)
|
||||
|
||||
try:
|
||||
_ = client.data_connection_service_create_data_connection(
|
||||
body=ProjectIdDataConnectionsBody(
|
||||
client.data_connection_service_create_data_connection(
|
||||
body=Create(
|
||||
name=name,
|
||||
region=region,
|
||||
source=source,
|
||||
destination=destination,
|
||||
aws=V1AwsDataConnection(
|
||||
region=region,
|
||||
source=source,
|
||||
destination=destination,
|
||||
secret_arn_name=secret_arn_name,
|
||||
),
|
||||
),
|
||||
project_id=project_id,
|
||||
)
|
||||
|
@ -86,8 +103,8 @@ def connect_data(
|
|||
# project_id=project_id,
|
||||
# id=response.id,
|
||||
# )
|
||||
# print(response)
|
||||
except Exception:
|
||||
_error_and_exit("The data connection creation failed.")
|
||||
except lightning_cloud.openapi.rest.ApiException as e:
|
||||
message = ast.literal_eval(e.body.decode("utf-8"))["message"]
|
||||
_error_and_exit(f"The data connection creation failed. Message: {message}")
|
||||
|
||||
rich.print(f"[green]Succeeded[/green]: You have created a new data connection {name}.")
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
@ -5,7 +6,7 @@ import pytest
|
|||
from lightning.app.cli.connect import data
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="In progress")
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
|
||||
def test_connect_data_no_project(monkeypatch):
|
||||
|
||||
from lightning_cloud.openapi import V1ListMembershipsResponse, V1Membership
|
||||
|
@ -26,10 +27,10 @@ def test_connect_data_no_project(monkeypatch):
|
|||
_get_project.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="In progress")
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
|
||||
def test_connect_data(monkeypatch):
|
||||
|
||||
from lightning_cloud.openapi import ProjectIdDataConnectionsBody, V1ListMembershipsResponse, V1Membership
|
||||
from lightning_cloud.openapi import Create, V1AwsDataConnection, V1ListMembershipsResponse, V1Membership
|
||||
|
||||
client = MagicMock()
|
||||
client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
|
||||
|
@ -53,10 +54,8 @@ def test_connect_data(monkeypatch):
|
|||
|
||||
client.data_connection_service_create_data_connection.assert_called_with(
|
||||
project_id="project-id-0",
|
||||
body=ProjectIdDataConnectionsBody(
|
||||
destination="",
|
||||
region="us-east-1",
|
||||
body=Create(
|
||||
name="imagenet",
|
||||
source="s3://imagenet",
|
||||
aws=V1AwsDataConnection(destination="", region="us-east-1", source="s3://imagenet", secret_arn_name=""),
|
||||
),
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue