lightning/tests/tests_app/cli/test_cmd_clusters.py

255 lines
9.5 KiB
Python

from unittest import mock
from unittest.mock import MagicMock
import click
import pytest
from lightning_cloud.openapi import (
Externalv1LightningappInstance,
V1AWSClusterDriverSpec,
V1ClusterDriver,
V1ClusterPerformanceProfile,
V1ClusterSpec,
V1ClusterState,
V1ClusterStatus,
V1ClusterType,
V1CreateClusterRequest,
V1GetClusterResponse,
V1KubernetesClusterDriver,
V1LightningappInstanceState,
V1LightningappInstanceStatus,
V1ListLightningappInstancesResponse,
V1ListMembershipsResponse,
V1Membership,
)
from lightning_app.cli import cmd_clusters
from lightning_app.cli.cmd_clusters import AWSClusterManager
@pytest.fixture(params=[True, False])
def async_or_interrupt(request, monkeypatch):
    """Run the requesting test twice: once with ``True`` and once with ``False``.

    For the ``False`` variant, ``cmd_clusters._wait_for_cluster_state`` is replaced
    with a mock that raises ``KeyboardInterrupt``. The parameter value itself is
    returned in both variants.
    """
    flag = request.param
    if flag:
        return flag
    # Simulate hitting ctrl-c immediately while waiting for cluster to create
    interrupted_wait = mock.MagicMock(side_effect=KeyboardInterrupt)
    monkeypatch.setattr(cmd_clusters, "_wait_for_cluster_state", interrupted_wait)
    return flag
@pytest.fixture
def spec():
    """Return a cluster spec whose Kubernetes/AWS driver sets only ``bucket_name="test-bucket"``."""
    aws_driver = V1AWSClusterDriverSpec(bucket_name="test-bucket")
    kubernetes_driver = V1KubernetesClusterDriver(aws=aws_driver)
    cluster_driver = V1ClusterDriver(kubernetes=kubernetes_driver)
    return V1ClusterSpec(driver=cluster_driver)
class FakeLightningClient:
    """Minimal stand-in for the Lightning client that serves canned get-cluster responses.

    Args:
        get_responses: Responses handed out by ``cluster_service_get_cluster``.
            When omitted, every instance gets its own fresh empty list.
        consume: When true, each call pops and returns the next response; when
            false, the first response is returned on every call and the list is
            left untouched.
    """

    def __init__(self, get_responses=None, consume=True):
        # A literal ``[]`` default is created once and shared by every instance
        # built without an explicit list, so state would leak between tests.
        self.get_responses = [] if get_responses is None else get_responses
        self.get_call_count = 0
        self.consume = consume

    def cluster_service_get_cluster(self, id: str):
        """Count the call and return the next (or, if not consuming, the first) response.

        ``id`` is accepted but not inspected.

        Raises:
            IndexError: If no response is available.
        """
        self.get_call_count += 1
        if self.consume:
            return self.get_responses.pop(0)
        return self.get_responses[0]
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_create_cluster")
def test_create_cluster_api(api: mock.MagicMock, async_or_interrupt):
    """Creating a cluster must issue exactly one create request with the expected BYOC spec."""
    cluster_id = "test-7"
    external_id = "dummy"
    role_arn = "arn:aws:iam::1234567890:role/lai-byoc"
    region = "us-west-2"

    AWSClusterManager().create(
        cluster_id=cluster_id,
        external_id=external_id,
        role_arn=role_arn,
        region=region,
        do_async=async_or_interrupt,
    )

    # Build the expected request from the inside out.
    aws_spec = V1AWSClusterDriverSpec(
        region=region,
        role_arn=role_arn,
        external_id=external_id,
    )
    expected_spec = V1ClusterSpec(
        cluster_type=V1ClusterType.BYOC,
        performance_profile=V1ClusterPerformanceProfile.DEFAULT,
        driver=V1ClusterDriver(kubernetes=V1KubernetesClusterDriver(aws=aws_spec)),
    )
    expected_request = V1CreateClusterRequest(name=cluster_id, spec=expected_spec)
    api.assert_called_once_with(body=expected_request)
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_list_clusters")
def test_list_clusters(api: mock.MagicMock):
    """Listing clusters must call the list API once with ``phase_not_in=[DELETED]``."""
    AWSClusterManager().list()
    excluded_phases = [V1ClusterState.DELETED]
    api.assert_called_once_with(phase_not_in=excluded_phases)
@pytest.fixture()
def fixture_list_instances_empty():
    """Return a list-instances response built from an empty list."""
    empty_response = V1ListLightningappInstancesResponse([])
    return empty_response
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships")
@mock.patch("lightning_app.utilities.network.LightningClient.lightningapp_instance_service_list_lightningapp_instances")
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_delete_cluster")
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_get_cluster")
def test_delete_cluster_api(
    api_get: mock.MagicMock,
    api_delete: mock.MagicMock,
    api_list_instances: mock.MagicMock,
    api_list_memberships: mock.MagicMock,
    async_or_interrupt,
    spec,
    fixture_list_instances_empty,
):
    """With an empty instance listing, delete must be called once with ``force=False``."""
    # Canned API responses: one project membership, no app instances, a cluster with ``spec``.
    membership = V1Membership(project_id="test-project")
    api_list_memberships.return_value = V1ListMembershipsResponse([membership])
    api_list_instances.return_value = fixture_list_instances_empty
    api_get.return_value = V1GetClusterResponse(spec=spec)

    manager = AWSClusterManager()
    manager.delete(cluster_id="test-7", do_async=async_or_interrupt)

    api_delete.assert_called_once_with(id="test-7", force=False)
@mock.patch("click.confirm")
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships")
@mock.patch("lightning_app.utilities.network.LightningClient.lightningapp_instance_service_list_lightningapp_instances")
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_delete_cluster")
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_get_cluster")
def test_delete_cluster_with_stopped_apps(
    api_get: mock.MagicMock,
    api_delete: mock.MagicMock,
    api_list_instances: mock.MagicMock,
    api_list_memberships: mock.MagicMock,
    click_mock: mock.MagicMock,
    spec,
):
    """When only stopped apps are listed, delete proceeds after a single confirmation prompt."""
    membership = V1Membership(project_id="test-project")
    api_list_memberships.return_value = V1ListMembershipsResponse([membership])

    stopped_app = Externalv1LightningappInstance(
        status=V1LightningappInstanceStatus(phase=V1LightningappInstanceState.STOPPED),
    )
    # Responses are served in call order.
    no_running_apps = V1ListLightningappInstancesResponse([])  # when querying for running apps
    one_stopped_app = V1ListLightningappInstancesResponse(lightningapps=[stopped_app])  # when querying for stopped apps
    api_list_instances.side_effect = [no_running_apps, one_stopped_app]

    api_get.return_value = V1GetClusterResponse(spec=spec)

    manager = AWSClusterManager()
    manager.delete(cluster_id="test-7", do_async=True)

    api_delete.assert_called_once_with(id="test-7", force=False)
    click_mock.assert_called_once_with("Are you sure you want to continue?", abort=True)
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships")
@mock.patch("lightning_app.utilities.network.LightningClient.lightningapp_instance_service_list_lightningapp_instances")
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_get_cluster")
def test_delete_cluster_with_running_apps(
    api_get: mock.MagicMock,
    api_list_instances: mock.MagicMock,
    api_list_memberships: mock.MagicMock,
    spec,
):
    """Deleting while a running app is listed must raise a ``ClickException`` mentioning "apps running"."""
    membership = V1Membership(project_id="test-project")
    api_list_memberships.return_value = V1ListMembershipsResponse([membership])

    running_app = Externalv1LightningappInstance(
        status=V1LightningappInstanceStatus(phase=V1LightningappInstanceState.RUNNING),
    )
    api_list_instances.return_value = V1ListLightningappInstancesResponse(lightningapps=[running_app])
    api_get.return_value = V1GetClusterResponse(spec=spec)

    manager = AWSClusterManager()

    # ``match`` runs the same regex search against the exception text as ``ExceptionInfo.match``.
    with pytest.raises(click.ClickException, match="apps running"):
        manager.delete(cluster_id="test-7")
class Test_check_cluster_id_is_valid:
    """Checks for ``cmd_clusters._check_cluster_id_is_valid``."""

    @pytest.mark.parametrize("name", ["test-7", "0wildgoat"])
    def test_valid(self, name):
        """An acceptable name yields a truthy result."""
        result = cmd_clusters._check_cluster_id_is_valid(None, None, name)
        assert result

    @pytest.mark.parametrize(
        "name",
        [
            "(&%)!@#",
            "1234567890123456789012345678901234567890123456789012345678901234567890",
        ],
    )
    def test_invalid(self, name):
        """An unacceptable name raises a ``ClickException`` carrying the regex-mismatch message."""
        with pytest.raises(click.ClickException) as e:
            cmd_clusters._check_cluster_id_is_valid(None, None, name)
        # Checked after the ``with`` block: a statement placed after the raising
        # call inside the block would never execute.
        message = str(e.value)
        assert "cluster name doesn't match regex pattern" in message
class Test_wait_for_cluster_state:
    """Exercises ``cmd_clusters._wait_for_cluster_state`` against a ``FakeLightningClient``."""

    # TODO(rra) add tests for pagination

    @pytest.mark.parametrize("target_state", [V1ClusterState.RUNNING, V1ClusterState.DELETED])
    @pytest.mark.parametrize(
        "previous_state", [V1ClusterState.QUEUED, V1ClusterState.PENDING, V1ClusterState.UNSPECIFIED]
    )
    def test_happy_path(self, target_state, previous_state, spec):
        """The client is polled exactly twice when the second response is in the target state."""
        phases = (previous_state, target_state)
        canned_responses = [
            V1GetClusterResponse(id="test-cluster", status=V1ClusterStatus(phase=phase), spec=spec)
            for phase in phases
        ]
        client = FakeLightningClient(get_responses=canned_responses)

        cmd_clusters._wait_for_cluster_state(client, "test-cluster", target_state, poll_duration_seconds=0.1)

        assert client.get_call_count == 2

    @pytest.mark.parametrize("target_state", [V1ClusterState.RUNNING, V1ClusterState.DELETED])
    def test_times_out(self, target_state, spec):
        """A cluster that stays ``UNSPECIFIED`` makes the wait raise a ``ClickException`` naming the target state."""
        stuck_response = V1GetClusterResponse(
            id="test-cluster",
            status=V1ClusterStatus(phase=V1ClusterState.UNSPECIFIED),
            spec=spec,
        )
        # ``consume=False`` keeps serving the same response on every poll.
        client = FakeLightningClient(get_responses=[stuck_response], consume=False)

        with pytest.raises(click.ClickException) as e:
            cmd_clusters._wait_for_cluster_state(
                client,
                "test-cluster",
                target_state,
                timeout_seconds=0.1,
                poll_duration_seconds=0.1,
            )

        state_labels = {
            V1ClusterState.DELETED: "deleted",
            V1ClusterState.RUNNING: "running",
        }
        expected_state = state_labels[target_state]
        assert e.match(f"The cluster has not entered the {expected_state} state")