diff --git a/requirements/app/base.txt b/requirements/app/base.txt index 87b1c1f0b5..c656a00b51 100644 --- a/requirements/app/base.txt +++ b/requirements/app/base.txt @@ -1,4 +1,4 @@ -lightning-cloud>=0.5.24 +lightning-cloud>=0.5.26 packaging typing-extensions>=4.0.0, <=4.4.0 deepdiff>=5.7.0, <6.2.4 diff --git a/src/lightning/app/CHANGELOG.md b/src/lightning/app/CHANGELOG.md index 6320b66d13..00038be348 100644 --- a/src/lightning/app/CHANGELOG.md +++ b/src/lightning/app/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added `lightning connect data` to register data connection to private s3 buckets ([#16738](https://github.com/Lightning-AI/lightning/pull/16738)) ### Changed diff --git a/src/lightning/app/cli/commands/ls.py b/src/lightning/app/cli/commands/ls.py index 4b3c2a8414..96f09c6fc4 100644 --- a/src/lightning/app/cli/commands/ls.py +++ b/src/lightning/app/cli/commands/ls.py @@ -19,7 +19,6 @@ from typing import Generator, List, Optional import click import rich -from fastapi import HTTPException from lightning_cloud.openapi import Externalv1LightningappInstance from rich.console import Console from rich.live import Live @@ -65,7 +64,7 @@ def ls(path: Optional[str] = None, print: bool = True, use_live: bool = True) -> lines = f.readlines() root = lines[0].replace("\n", "") - client = LightningClient() + client = LightningClient(retry=False) projects = client.projects_service_list_memberships() if root == "/": @@ -256,7 +255,7 @@ def _collect_artifacts( page_token=response.next_page_token, tokens=tokens, ) - except HTTPException: + except Exception: # Note: This is triggered when the request is wrong. # This is currently happening due to looping through the user clusters. pass diff --git a/src/lightning/app/cli/connect/data.py b/src/lightning/app/cli/connect/data.py index d4d4af2db7..0a35b788f1 100644 --- a/src/lightning/app/cli/connect/data.py +++ b/src/lightning/app/cli/connect/data.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ast import sys import click +import lightning_cloud import rich from rich.live import Live from rich.spinner import Spinner @@ -29,20 +31,32 @@ logger = Logger(__name__) @click.argument("name", required=True) -@click.argument("region", required=True) -@click.argument("source", required=True) -@click.argument("destination", required=False) -@click.argument("project_name", required=False) +@click.option("--region", help="The AWS region of your bucket. Example: `us-west-1`.", required=True) +@click.option( + "--source", help="The URL path to your AWS S3 folder. Example: `s3://pl-flash-data/images/`.", required=True +) +@click.option( + "--secret_arn_name", + help="The name of role stored as a secret on Lightning AI to access your data. " + "Learn more with https://gist.github.com/tchaton/12ad4b788012e83c0eb35e6223ae09fc. " + "Example: `my_role`.", + required=False, +) +@click.option( + "--destination", help="Where your data should appear in the cloud. Currently not supported.", required=False +) +@click.option("--project_name", help="The project name on which to create the data connection.", required=False) def connect_data( name: str, region: str, source: str, + secret_arn_name: str = "", destination: str = "", project_name: str = "", ) -> None: """Create a new data connection.""" - from lightning_cloud.openapi import ProjectIdDataConnectionsBody + from lightning_cloud.openapi import Create, V1AwsDataConnection if sys.platform == "win32": _error_and_exit("Data connection isn't supported on windows. Open an issue on Github.") @@ -51,7 +65,7 @@ def connect_data( live.stop() - client = LightningClient() + client = LightningClient(retry=False) projects = client.projects_service_list_memberships() project_id = None @@ -71,12 +85,15 @@ def connect_data( ) try: - _ = client.data_connection_service_create_data_connection( - body=ProjectIdDataConnectionsBody( + client.data_connection_service_create_data_connection( + body=Create( name=name, - region=region, - source=source, - destination=destination, + aws=V1AwsDataConnection( + region=region, + source=source, + destination=destination, + secret_arn_name=secret_arn_name, + ), ), project_id=project_id, ) @@ -86,8 +103,8 @@ def connect_data( # project_id=project_id, # id=response.id, # ) - # print(response) - except Exception: - _error_and_exit("The data connection creation failed.") + except lightning_cloud.openapi.rest.ApiException as e: + message = ast.literal_eval(e.body.decode("utf-8"))["message"] + _error_and_exit(f"The data connection creation failed. Message: {message}") rich.print(f"[green]Succeeded[/green]: You have created a new data connection {name}.") diff --git a/tests/tests_app/cli/test_connect_data.py b/tests/tests_app/cli/test_connect_data.py index ff88982448..3243ca28c4 100644 --- a/tests/tests_app/cli/test_connect_data.py +++ b/tests/tests_app/cli/test_connect_data.py @@ -1,3 +1,4 @@ +import sys from unittest.mock import MagicMock import pytest @@ -5,7 +6,7 @@ import pytest from lightning.app.cli.connect import data -@pytest.mark.skipif(True, reason="In progress") +@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows") def test_connect_data_no_project(monkeypatch): from lightning_cloud.openapi import V1ListMembershipsResponse, V1Membership @@ -26,10 +27,10 @@ def test_connect_data_no_project(monkeypatch): _get_project.assert_called() -@pytest.mark.skipif(True, reason="In progress") +@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows") def test_connect_data(monkeypatch): - from lightning_cloud.openapi import ProjectIdDataConnectionsBody, V1ListMembershipsResponse, V1Membership + from lightning_cloud.openapi import Create, V1AwsDataConnection, V1ListMembershipsResponse, V1Membership client = MagicMock() client.projects_service_list_memberships.return_value = V1ListMembershipsResponse( @@ -53,10 +54,8 @@ def test_connect_data(monkeypatch): client.data_connection_service_create_data_connection.assert_called_with( project_id="project-id-0", - body=ProjectIdDataConnectionsBody( - destination="", - region="us-east-1", + body=Create( name="imagenet", - source="s3://imagenet", + aws=V1AwsDataConnection(destination="", region="us-east-1", source="s3://imagenet", secret_arn_name=""), ), )