59 lines
2.4 KiB
Python
59 lines
2.4 KiB
Python
import sys
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from lightning.app.cli.connect import data
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
def test_connect_data_no_project(monkeypatch):
    """Check that ``connect_data`` calls ``_get_project`` when the client lists no memberships.

    ``LightningClient``, ``_error_and_exit`` and ``_get_project`` are all replaced
    on the ``data`` module, so no real client is constructed by the test.
    """
    from lightning_cloud.openapi import V1ListMembershipsResponse, V1Membership

    # Client stub whose membership listing comes back empty.
    fake_client = MagicMock()
    fake_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(memberships=[])

    # Project lookup stub; the assertion at the end is made against this mock.
    project_lookup = MagicMock(return_value=V1Membership(name="project-0", project_id="project-id-0"))

    monkeypatch.setattr(data, "LightningClient", MagicMock(return_value=fake_client))
    # Any call connect_data makes to the error helper is absorbed by a mock.
    monkeypatch.setattr(data, "_error_and_exit", MagicMock())
    monkeypatch.setattr(data, "_get_project", project_lookup)

    data.connect_data("imagenet", region="us-east-1", source="imagenet", destination="", project_name="project-0")

    project_lookup.assert_called()
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
def test_connect_data(monkeypatch):
    """Exercise ``connect_data`` with a non-``s3://`` source and then an ``s3://`` source.

    The first call is expected to report the "only public S3 folders" error via
    ``_error_and_exit``; the second is expected to create a data connection on
    the client for project ``project-id-0``.
    """
    from lightning_cloud.openapi import Create, V1AwsDataConnection, V1ListMembershipsResponse, V1Membership

    # Memberships returned by the stubbed client.
    memberships = [
        V1Membership(name=project_name, project_id=project_id)
        for project_name, project_id in (
            ("project-0", "project-id-0"),
            ("project-1", "project-id-1"),
            ("project 2", "project-id-2"),
        )
    ]

    fake_client = MagicMock()
    fake_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(memberships=memberships)
    monkeypatch.setattr(data, "LightningClient", MagicMock(return_value=fake_client))

    # Mocked so the error message can be inspected after the call.
    error_spy = MagicMock()
    monkeypatch.setattr(data, "_error_and_exit", error_spy)

    # First call: source has no "s3://" prefix.
    data.connect_data("imagenet", region="us-east-1", source="imagenet", destination="", project_name="project-0")

    error_spy.assert_called_with(
        "Only public S3 folders are supported for now. Please, open a Github issue with your use case."
    )

    # Second call: source carries the "s3://" prefix.
    data.connect_data("imagenet", region="us-east-1", source="s3://imagenet", destination="", project_name="project-0")

    expected_body = Create(
        name="imagenet",
        aws=V1AwsDataConnection(destination="", region="us-east-1", source="s3://imagenet", secret_arn_name=""),
    )
    fake_client.data_connection_service_create_data_connection.assert_called_with(
        project_id="project-id-0",
        body=expected_body,
    )
|