136 lines
4.8 KiB
Python
136 lines
4.8 KiB
Python
from unittest import mock
|
|
from unittest.mock import MagicMock
|
|
|
|
import click
|
|
import pytest
|
|
from lightning_cloud.openapi import (
|
|
Externalv1Cluster,
|
|
V1AWSClusterDriverSpec,
|
|
V1ClusterDriver,
|
|
V1ClusterPerformanceProfile,
|
|
V1ClusterSpec,
|
|
V1ClusterState,
|
|
V1ClusterStatus,
|
|
V1ClusterType,
|
|
V1CreateClusterRequest,
|
|
V1KubernetesClusterDriver,
|
|
V1ListClustersResponse,
|
|
)
|
|
|
|
from lightning_app.cli import cmd_clusters
|
|
from lightning_app.cli.cmd_clusters import AWSClusterManager
|
|
|
|
|
|
class FakeLightningClient:
|
|
def __init__(self, list_responses=[], consume=True):
|
|
self.list_responses = list_responses
|
|
self.list_call_count = 0
|
|
self.consume = consume
|
|
|
|
def cluster_service_list_clusters(self, phase_not_in=None):
|
|
self.list_call_count = self.list_call_count + 1
|
|
if self.consume:
|
|
return self.list_responses.pop()
|
|
return self.list_responses[0]
|
|
|
|
|
|
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
|
|
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_create_cluster")
|
|
def test_create_cluster(api: mock.MagicMock):
|
|
cluster_manager = AWSClusterManager()
|
|
cluster_manager.create(
|
|
cluster_name="test-7",
|
|
external_id="dummy",
|
|
role_arn="arn:aws:iam::1234567890:role/lai-byoc",
|
|
region="us-west-2",
|
|
)
|
|
|
|
api.assert_called_once_with(
|
|
body=V1CreateClusterRequest(
|
|
name="test-7",
|
|
spec=V1ClusterSpec(
|
|
cluster_type=V1ClusterType.BYOC,
|
|
performance_profile=V1ClusterPerformanceProfile.DEFAULT,
|
|
driver=V1ClusterDriver(
|
|
kubernetes=V1KubernetesClusterDriver(
|
|
aws=V1AWSClusterDriverSpec(
|
|
region="us-west-2",
|
|
role_arn="arn:aws:iam::1234567890:role/lai-byoc",
|
|
external_id="dummy",
|
|
)
|
|
)
|
|
),
|
|
),
|
|
)
|
|
)
|
|
|
|
|
|
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
|
|
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_list_clusters")
|
|
def test_list_clusters(api: mock.MagicMock):
|
|
cluster_manager = AWSClusterManager()
|
|
cluster_manager.list()
|
|
|
|
api.assert_called_once_with(phase_not_in=[V1ClusterState.DELETED])
|
|
|
|
|
|
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
|
|
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_delete_cluster")
|
|
def test_delete_cluster(api: mock.MagicMock):
|
|
cluster_manager = AWSClusterManager()
|
|
cluster_manager.delete(cluster_id="test-7")
|
|
|
|
api.assert_called_once_with(id="test-7", force=False)
|
|
|
|
|
|
class Test_check_cluster_name_is_valid:
|
|
@pytest.mark.parametrize("name", ["test-7", "0wildgoat"])
|
|
def test_valid(self, name):
|
|
assert cmd_clusters._check_cluster_name_is_valid(None, None, name)
|
|
|
|
@pytest.mark.parametrize(
|
|
"name", ["(&%)!@#", "1234567890123456789012345678901234567890123456789012345678901234567890"]
|
|
)
|
|
def test_invalid(self, name):
|
|
with pytest.raises(click.ClickException) as e:
|
|
cmd_clusters._check_cluster_name_is_valid(None, None, name)
|
|
assert "cluster name doesn't match regex pattern" in str(e.value)
|
|
|
|
|
|
class Test_wait_for_cluster_state:
|
|
# TODO(rra) add tests for pagination
|
|
|
|
@pytest.mark.parametrize("target_state", [V1ClusterState.RUNNING, V1ClusterState.DELETED])
|
|
@pytest.mark.parametrize(
|
|
"previous_state", [V1ClusterState.QUEUED, V1ClusterState.PENDING, V1ClusterState.UNSPECIFIED]
|
|
)
|
|
def test_happy_path(self, target_state, previous_state):
|
|
client = FakeLightningClient(
|
|
list_responses=[
|
|
V1ListClustersResponse(
|
|
clusters=[Externalv1Cluster(id="test-cluster", status=V1ClusterStatus(phase=state))]
|
|
)
|
|
for state in [previous_state, target_state]
|
|
]
|
|
)
|
|
cmd_clusters._wait_for_cluster_state(client, "test-cluster", target_state, check_timeout=0.1)
|
|
assert client.list_call_count == 1
|
|
|
|
@pytest.mark.parametrize("target_state", [V1ClusterState.RUNNING, V1ClusterState.DELETED])
|
|
def test_times_out(self, target_state):
|
|
client = FakeLightningClient(
|
|
list_responses=[
|
|
V1ListClustersResponse(
|
|
clusters=[
|
|
Externalv1Cluster(id="test-cluster", status=V1ClusterStatus(phase=V1ClusterState.UNSPECIFIED))
|
|
]
|
|
)
|
|
],
|
|
consume=False,
|
|
)
|
|
with pytest.raises(click.ClickException) as e:
|
|
cmd_clusters._wait_for_cluster_state(
|
|
client, "test-cluster", target_state, max_wait_time=0.4, check_timeout=0.2
|
|
)
|
|
assert "Max wait time elapsed" in str(e.value)
|