# Source listing: 244 lines, 9.5 KiB, Python.
import os
|
|
import sys
|
|
from pathlib import PosixPath
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from lightning.app.cli.commands import cp
|
|
from lightning.app.cli.commands.cd import _CD_FILE, cd
|
|
from lightning_cloud.openapi import (
|
|
Externalv1Cluster,
|
|
Externalv1LightningappInstance,
|
|
V1CloudSpace,
|
|
V1ClusterDriver,
|
|
V1ClusterSpec,
|
|
V1GetClusterResponse,
|
|
V1KubernetesClusterDriver,
|
|
V1LightningappInstanceArtifact,
|
|
V1LightningappInstanceSpec,
|
|
V1ListCloudSpacesResponse,
|
|
V1ListClustersResponse,
|
|
V1ListLightningappInstanceArtifactsResponse,
|
|
V1ListLightningappInstancesResponse,
|
|
V1ListMembershipsResponse,
|
|
V1ListProjectClusterBindingsResponse,
|
|
V1Membership,
|
|
V1ProjectClusterBinding,
|
|
V1UploadProjectArtifactResponse,
|
|
)
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
def test_cp_local_to_remote(tmpdir, monkeypatch):
    """Upload a local directory with ``cp`` to the current remote directory.

    While the remote working directory is the root ("/"), the upload must be reported
    through ``_error_and_exit`` with the project-level error message. After ``cd``-ing
    into an app directory, each local file must be handed to ``FileUploader`` (checked
    via its ``name`` keyword argument).

    ``cd`` persists the remote working directory in ``_CD_FILE``; the file is removed in
    a ``finally`` so a failing assertion does not leak that state into other tests.
    """
    error_and_exit = MagicMock()
    monkeypatch.setattr(cp, "_error_and_exit", error_and_exit)

    client = MagicMock()
    client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
        memberships=[V1Membership(name="project-0")]
    )
    client.lightningapp_instance_service_list_lightningapp_instances.return_value = V1ListLightningappInstancesResponse(
        lightningapps=[Externalv1LightningappInstance(name="app-name-0", id="app-id-0")]
    )
    client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse(
        clusters=[V1ProjectClusterBinding(cluster_id="my-cluster", cluster_name="my-cluster")]
    )

    result = MagicMock()
    result.get.return_value = V1UploadProjectArtifactResponse(urls=["http://foo.bar"])
    client.lightningapp_instance_service_upload_project_artifact.return_value = result

    monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))

    try:
        # At the root, uploads are refused. `_error_and_exit` is mocked, so the message
        # is recorded for inspection instead of being acted on.
        assert cd("/", verify=False) == "/"
        cp.cp(str(tmpdir), "r:.")
        assert error_and_exit.call_args_list[0].args[0] == "Uploading files at the project level isn't allowed yet."

        assert cd("/project-0/app-name-0", verify=False) == "/project-0/app-name-0"
        with open(f"{tmpdir}/a.txt", "w") as f:
            f.write("hello world !")

        file_uploader = MagicMock()
        monkeypatch.setattr(cp, "FileUploader", file_uploader)

        cp.cp(str(tmpdir), "r:.")
        assert file_uploader.call_args.kwargs["name"] == f"{tmpdir}/a.txt"
    finally:
        # Clean up the persisted remote working directory even when an assertion above
        # fails; guard on existence so cleanup never masks the original failure.
        if os.path.exists(_CD_FILE):
            os.remove(_CD_FILE)
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
def test_cp_cloud_to_local(tmpdir, monkeypatch):
    """Download an app's artifacts with ``cp`` into a local directory.

    First repeats the root-level check: uploading while at "/" is reported through
    ``_error_and_exit``. Then, from inside an app directory, ``cp r:. <tmpdir>`` must
    call ``_download_file`` once per listed artifact, in listing order, with the local
    destination ``tmpdir / filename`` and the artifact URL.

    ``_CD_FILE`` (written by ``cd``) is removed in a ``finally`` so a failing assertion
    does not leak the remote working directory into other tests.
    """
    error_and_exit = MagicMock()
    monkeypatch.setattr(cp, "_error_and_exit", error_and_exit)

    client = MagicMock()
    client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
        memberships=[V1Membership(name="project-0")]
    )

    # Presumably `cp` only needs a non-empty cluster-binding list here; a bare mock suffices.
    clusters = MagicMock()
    clusters.clusters = [MagicMock()]
    client.projects_service_list_project_cluster_bindings.return_value = clusters

    client.lightningapp_instance_service_list_lightningapp_instances.return_value = V1ListLightningappInstancesResponse(
        lightningapps=[Externalv1LightningappInstance(name="app-name-0", id="app-id-0")]
    )

    artifacts = [
        V1LightningappInstanceArtifact(filename=".file_1.txt", url="http://foo.bar/file_1.txt", size_bytes=123),
        V1LightningappInstanceArtifact(
            filename=".folder_1/file_2.txt", url="http://foo.bar/folder_1/file_2.txt", size_bytes=123
        ),
        V1LightningappInstanceArtifact(
            filename=".folder_2/folder_3/file_3.txt", url="http://foo.bar/folder_2/folder_3/file_3.txt", size_bytes=123
        ),
        V1LightningappInstanceArtifact(
            filename=".folder_4/file_4.txt", url="http://foo.bar/folder_4/file_4.txt", size_bytes=123
        ),
    ]
    client.lightningapp_instance_service_list_project_artifacts.return_value = (
        V1ListLightningappInstanceArtifactsResponse(artifacts=artifacts)
    )

    monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))

    try:
        assert cd("/", verify=False) == "/"
        cp.cp(str(tmpdir), "r:.")
        assert error_and_exit.call_args_list[0].args[0] == "Uploading files at the project level isn't allowed yet."

        assert cd("/project-0/app-name-0", verify=False) == "/project-0/app-name-0"

        download_file = MagicMock()
        monkeypatch.setattr(cp, "_download_file", download_file)

        cp.cp("r:.", str(tmpdir))

        # Exactly one download per artifact; the length check makes the zip below exhaustive.
        assert len(download_file.call_args_list) == len(artifacts)
        for artifact, download_call in zip(artifacts, download_file.call_args_list):
            assert download_call.args[0] == PosixPath(tmpdir / artifact.filename)
            assert download_call.args[1] == artifact.url
    finally:
        # Guard on existence so cleanup never masks an earlier assertion failure.
        if os.path.exists(_CD_FILE):
            os.remove(_CD_FILE)
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
def test_sanitize_path():
    """Check how ``cp._sanitize_path`` resolves remote (``r:``) and local paths against a cwd."""
    # A remote path given relative to the root is anchored at "/" and flagged remote.
    sanitized, remote = cp._sanitize_path("r:default-project", "/")
    assert sanitized == "/default-project"
    assert remote

    # A remote path relative to a project directory is joined onto that directory.
    sanitized, _ = cp._sanitize_path("r:foo", "/default-project")
    assert sanitized == "/default-project/foo"

    # A local path (no "r:" prefix) is returned as-is.
    sanitized, _ = cp._sanitize_path("foo", "/default-project")
    assert sanitized == "foo"
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
def test_cp_zip_arg_order(monkeypatch):
    """Zipping a local source to a remote destination is reported as unsupported."""
    assert cd("/", verify=False) == "/"

    reporter = MagicMock()
    monkeypatch.setattr(cp, "_error_and_exit", reporter)
    monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=MagicMock()))

    cp.cp("./my-resource", "r:./my-resource", zip=True)

    reporter.assert_called_once()
    # Called exactly once, so `call_args` is that single call.
    reported_message = reporter.call_args.args[0]
    assert "Zipping uploads isn't supported yet" in reported_message
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
def test_cp_zip_src_path_too_short(monkeypatch):
    """A zipped copy whose remote source is only one level deep (e.g. ``/my-project``) is rejected."""
    reporter = MagicMock()
    monkeypatch.setattr(cp, "_error_and_exit", reporter)
    monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=MagicMock()))

    cp.cp("r:/my-project", ".", zip=True)

    reporter.assert_called_once()
    reported_message = reporter.call_args.args[0]
    assert "The source path must be at least two levels deep" in reported_message
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
def test_cp_zip_remote_to_local_cloudspace_artifact(monkeypatch):
    """Zipped download of a cloudspace artifact.

    ``cp`` must call ``_download_file`` once, targeting ``./<artifact>.zip``. The URL must
    be built from the cluster's root domain, the project id, the cloudspace id and the
    auth token (all supplied by the mocks below).
    """
    assert cd("/", verify=False) == "/"

    token_getter = MagicMock()
    token_getter._get_api_token.return_value = "my-token"
    monkeypatch.setattr(cp, "_AuthTokenGetter", MagicMock(return_value=token_getter))

    cluster_driver = V1ClusterDriver(kubernetes=V1KubernetesClusterDriver(root_domain_name="my-domain"))
    cluster = Externalv1Cluster(id="my-cluster", spec=V1ClusterSpec(driver=cluster_driver))

    client = MagicMock()
    client.cluster_service_list_clusters.return_value = V1ListClustersResponse(
        default_cluster="my-cluster",
        clusters=[cluster],
    )
    client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
        memberships=[V1Membership(name="my-project", project_id="my-project-id")]
    )
    client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
        cloudspaces=[V1CloudSpace(name="my-cloudspace", id="my-cloudspace-id")],
    )
    monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))

    download_file = MagicMock()
    monkeypatch.setattr(cp, "_download_file", download_file)

    cp.cp("r:/my-project/my-cloudspace/my-artifact", ".", zip=True)

    download_file.assert_called_once()
    download_args = download_file.call_args.args
    assert download_args[0] == "./my-artifact.zip"
    assert download_args[1] == (
        "https://storage.my-domain/v1/projects/my-project-id/artifacts/download"
        "?prefix=/cloudspaces/my-cloudspace-id/my-artifact&token=my-token"
    )
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
def test_cp_zip_remote_to_local_app_artifact(monkeypatch):
    """Zipped download of an app artifact.

    ``cp`` must call ``_download_file`` once, targeting ``./<artifact>.zip``. The URL must
    be built from the root domain returned by the mocked cluster lookup, the project id,
    the app id and the auth token.
    """
    assert cd("/", verify=False) == "/"

    token_getter = MagicMock()
    token_getter._get_api_token.return_value = "my-token"
    monkeypatch.setattr(cp, "_AuthTokenGetter", MagicMock(return_value=token_getter))

    cluster_driver = V1ClusterDriver(kubernetes=V1KubernetesClusterDriver(root_domain_name="my-domain"))
    app = Externalv1LightningappInstance(
        name="my-app", id="my-app-id", spec=V1LightningappInstanceSpec(cluster_id="my-cluster")
    )

    client = MagicMock()
    client.cluster_service_get_cluster.return_value = V1GetClusterResponse(spec=V1ClusterSpec(driver=cluster_driver))
    client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
        memberships=[V1Membership(name="my-project", project_id="my-project-id")]
    )
    client.lightningapp_instance_service_list_lightningapp_instances.return_value = V1ListLightningappInstancesResponse(
        lightningapps=[app]
    )
    monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))

    download_file = MagicMock()
    monkeypatch.setattr(cp, "_download_file", download_file)

    cp.cp("r:/my-project/my-app/my-artifact", ".", zip=True)

    download_file.assert_called_once()
    download_args = download_file.call_args.args
    assert download_args[0] == "./my-artifact.zip"
    assert download_args[1] == (
        "https://storage.my-domain/v1/projects/my-project-id/artifacts/download"
        "?prefix=/lightningapps/my-app-id/my-artifact&token=my-token"
    )
|