From 2919dcf7eea181a2706bbcb11201f2f8515cad36 Mon Sep 17 00:00:00 2001 From: Raphael Randschau Date: Tue, 2 Aug 2022 01:31:09 -0700 Subject: [PATCH] [CLI] add support for cluster management (#13835) --- src/lightning_app/CHANGELOG.md | 1 + src/lightning_app/cli/cmd_clusters.py | 206 ++++++++++++++++++ src/lightning_app/cli/core.py | 13 ++ src/lightning_app/cli/lightning_cli.py | 16 +- src/lightning_app/cli/lightning_cli_create.py | 86 ++++++++ src/lightning_app/cli/lightning_cli_delete.py | 49 +++++ src/lightning_app/cli/lightning_cli_list.py | 16 ++ src/lightning_app/testing/testing.py | 24 ++ src/lightning_app/utilities/openapi.py | 61 ++++++ tests/tests_app/cli/test_cli.py | 58 ++++- tests/tests_app/cli/test_cmd_clusters.py | 135 ++++++++++++ tests/tests_clusters/__init__.py | 0 .../tests_clusters/test_cluster_lifecycle.py | 53 +++++ 13 files changed, 707 insertions(+), 11 deletions(-) create mode 100644 src/lightning_app/cli/cmd_clusters.py create mode 100644 src/lightning_app/cli/core.py create mode 100644 src/lightning_app/cli/lightning_cli_create.py create mode 100644 src/lightning_app/cli/lightning_cli_delete.py create mode 100644 src/lightning_app/cli/lightning_cli_list.py create mode 100644 src/lightning_app/utilities/openapi.py create mode 100644 tests/tests_app/cli/test_cmd_clusters.py create mode 100644 tests/tests_clusters/__init__.py create mode 100644 tests/tests_clusters/test_cluster_lifecycle.py diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 89fcd61543..34fdb9665f 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added - Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602)) +- Add support for Lightning AI BYOC cluster management ([#13835](https://github.com/Lightning-AI/lightning/pull/13835)) - Adds `LightningTrainingComponent`. `LightningTrainingComponent` orchestrates multi-node training in the cloud ([#13830](https://github.com/Lightning-AI/lightning/pull/13830)) diff --git a/src/lightning_app/cli/cmd_clusters.py b/src/lightning_app/cli/cmd_clusters.py new file mode 100644 index 0000000000..7acdc9b630 --- /dev/null +++ b/src/lightning_app/cli/cmd_clusters.py @@ -0,0 +1,206 @@ +import json +import re +import time +from datetime import datetime + +import click +from lightning_cloud.openapi import ( + V1AWSClusterDriverSpec, + V1ClusterDriver, + V1ClusterPerformanceProfile, + V1ClusterSpec, + V1CreateClusterRequest, + V1InstanceSpec, + V1KubernetesClusterDriver, +) +from lightning_cloud.openapi.models import Externalv1Cluster, V1ClusterState, V1ClusterType +from rich.console import Console +from rich.table import Table +from rich.text import Text + +from lightning_app.cli.core import Formatable +from lightning_app.utilities.network import LightningClient +from lightning_app.utilities.openapi import create_openapi_object, string2dict + +CLUSTER_STATE_CHECKING_TIMEOUT = 60 +MAX_CLUSTER_WAIT_TIME = 5400 + + +class AWSClusterManager: + """AWSClusterManager implements API calls specific to Lightning AI BYOC compute clusters when the AWS provider + is selected as the backend compute.""" + + def __init__(self): + self.api_client = LightningClient() + + def create( + self, + cost_savings: bool = False, + cluster_name: str = None, + role_arn: str = None, + region: str = "us-east-1", + external_id: str = None, + instance_types: [str] = None, + edit_before_creation: bool = False, + wait: bool = False, + ): + """Request Lightning AI BYOC compute
cluster creation. + + Args: + cost_savings: Specifies if the cluster uses cost savings mode + cluster_name: The name of the cluster to be created + role_arn: AWS IAM Role ARN used to provision resources + region: AWS region containing compute resources + external_id: AWS IAM Role external ID + instance_types: AWS instance types supported by the cluster + edit_before_creation: Enables interactive editing of requests before submitting it to Lightning AI. + wait: Waits for the cluster to be in a RUNNING state. Only use this for debugging. + """ + performance_profile = V1ClusterPerformanceProfile.DEFAULT + if cost_savings: + """In cost saving mode the number of compute nodes is reduced to one, reducing the cost for clusters + with low utilization.""" + performance_profile = V1ClusterPerformanceProfile.COST_SAVING + + body = V1CreateClusterRequest( + name=cluster_name, + spec=V1ClusterSpec( + cluster_type=V1ClusterType.BYOC, + performance_profile=performance_profile, + driver=V1ClusterDriver( + kubernetes=V1KubernetesClusterDriver( + aws=V1AWSClusterDriverSpec( + region=region, + role_arn=role_arn, + external_id=external_id, + instance_types=[V1InstanceSpec(name=x) for x in (instance_types or [])], + ) + ) + ), + ), + ) + new_body = body + if edit_before_creation: + after = click.edit(json.dumps(body.to_dict(), indent=4)) + if after is not None: + new_body = create_openapi_object(string2dict(after), body) + if new_body == body: + click.echo("cluster unchanged") + + resp = self.api_client.cluster_service_create_cluster(body=new_body) + if wait: + _wait_for_cluster_state(self.api_client, resp.id, V1ClusterState.RUNNING) + + click.echo(f"{resp.id} cluster is {resp.status.phase}") + + def list(self): + resp = self.api_client.cluster_service_list_clusters(phase_not_in=[V1ClusterState.DELETED]) + console = Console() + console.print(ClusterList(resp.clusters).as_table()) + + def delete(self, cluster_id: str = None, force: bool = False, wait: bool = False): + if force: + click.echo(
+ """ + Deletes a BYOC cluster. Lightning AI removes cluster artifacts and any resources running on the cluster.\n + WARNING: Deleting a cluster does not clean up any resources managed by Lightning AI.\n + Check your cloud provider to verify that existing cloud resources are deleted. + """ + ) + click.confirm("Do you want to continue?", abort=True) + + self.api_client.cluster_service_delete_cluster(id=cluster_id, force=force) + click.echo("Cluster deletion triggered successfully") + + if wait: + _wait_for_cluster_state(self.api_client, cluster_id, V1ClusterState.DELETED) + + +class ClusterList(Formatable): + def __init__(self, clusters: [Externalv1Cluster]): + self.clusters = clusters + + def as_json(self) -> str: + return json.dumps([cluster.to_dict() for cluster in self.clusters], default=str) + + def as_table(self) -> Table: + table = Table("id", "name", "type", "status", "created", show_header=True, header_style="bold green") + phases = { + V1ClusterState.QUEUED: Text("queued", style="bold yellow"), + V1ClusterState.PENDING: Text("pending", style="bold yellow"), + V1ClusterState.RUNNING: Text("running", style="bold green"), + V1ClusterState.FAILED: Text("failed", style="bold red"), + V1ClusterState.DELETED: Text("deleted", style="bold red"), + } + + cluster_type_lookup = { + V1ClusterType.BYOC: Text("byoc", style="bold yellow"), + V1ClusterType.GLOBAL: Text("lightning-cloud", style="bold green"), + } + for cluster in self.clusters: + cluster: Externalv1Cluster + status = phases.get(cluster.status.phase, Text("unknown", style="red")) + if cluster.spec.desired_state == V1ClusterState.DELETED and cluster.status.phase != V1ClusterState.DELETED: + status = Text("terminating", style="bold red") + + # this guard is necessary only until 0.3.93 releases which includes the `created_at` + # field to the external API + created_at = datetime.now() + if hasattr(cluster, "created_at"): + created_at = cluster.created_at + + table.add_row( + cluster.id, + cluster.name, + cluster_type_lookup.get(cluster.spec.cluster_type, Text("unknown", style="red")), +
status, + created_at.strftime("%Y-%m-%d") if created_at else "", + ) + return table + + +def _wait_for_cluster_state( + api_client: LightningClient, + cluster_id: str, + target_state: V1ClusterState, + max_wait_time: int = MAX_CLUSTER_WAIT_TIME, + check_timeout: int = CLUSTER_STATE_CHECKING_TIMEOUT, +): + """_wait_for_cluster_state waits until the provided cluster has reached a desired state, or failed. + + Args: + api_client: LightningClient used for polling + cluster_id: Specifies the cluster to wait for + target_state: Specifies the desired state the target cluster needs to meet + max_wait_time: Maximum duration to wait (in seconds) + check_timeout: duration between polling for the cluster state (in seconds) + """ + start = time.time() + elapsed = 0 + while elapsed < max_wait_time: + cluster_resp = api_client.cluster_service_list_clusters() + new_cluster = None + for clust in cluster_resp.clusters: + if clust.id == cluster_id: + new_cluster = clust + break + if new_cluster is not None: + if new_cluster.status.phase == target_state: + break + elif new_cluster.status.phase == V1ClusterState.FAILED: + raise click.ClickException(f"Cluster {cluster_id} is in failed state.") + time.sleep(check_timeout) + elapsed = time.time() - start + else: + raise click.ClickException("Max wait time elapsed") + + +def _check_cluster_name_is_valid(_ctx, _param, value): + pattern = r"^(?!-)[a-z0-9-]{1,63}(? 
Table: + pass + + @abc.abstractmethod + def as_json(self) -> str: + pass diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py index 74b2d1c492..bb81b4eda1 100644 --- a/src/lightning_app/cli/lightning_cli.py +++ b/src/lightning_app/cli/lightning_cli.py @@ -12,6 +12,9 @@ from requests.exceptions import ConnectionError from lightning_app import __version__ as ver from lightning_app.cli import cmd_init, cmd_install, cmd_pl_init, cmd_react_ui_init +from lightning_app.cli.lightning_cli_create import create +from lightning_app.cli.lightning_cli_delete import delete +from lightning_app.cli.lightning_cli_list import get_list from lightning_app.core.constants import get_lightning_cloud_url, LOCAL_LAUNCH_ADMIN_VIEW from lightning_app.runners.runtime import dispatch from lightning_app.runners.runtime_type import RuntimeType @@ -206,16 +209,9 @@ def stop(): pass -@_main.group(hidden=True) -def delete(): - """Delete an application.""" - pass - - -@_main.group(name="list", hidden=True) -def get_list(): - """List your applications.""" - pass +_main.add_command(get_list) +_main.add_command(delete) +_main.add_command(create) @_main.group() diff --git a/src/lightning_app/cli/lightning_cli_create.py b/src/lightning_app/cli/lightning_cli_create.py new file mode 100644 index 0000000000..7e45fe7e7c --- /dev/null +++ b/src/lightning_app/cli/lightning_cli_create.py @@ -0,0 +1,86 @@ +import click + +from lightning_app.cli.cmd_clusters import _check_cluster_name_is_valid, AWSClusterManager + + +@click.group("create") +def create(): + """Create Lightning AI BYOC managed resources.""" + pass + + +@create.command("cluster") +@click.argument("cluster_name", callback=_check_cluster_name_is_valid) +@click.option("--provider", "provider", type=str, default="aws", help="cloud provider to be used for your cluster") +@click.option("--external-id", "external_id", type=str, required=True) +@click.option( + "--role-arn", "role_arn", type=str, required=True, 
help="AWS role ARN attached to the associated resources." +) +@click.option( + "--region", + "region", + type=str, + required=False, + default="us-east-1", + help="AWS region that is used to host the associated resources.", +) +@click.option( + "--instance-types", + "instance_types", + type=str, + required=False, + default=None, + help="Instance types that you want to support, for computer jobs within the cluster.", +) +@click.option( + "--cost-savings", + "cost_savings", + type=bool, + required=False, + default=False, + is_flag=True, + help="""Use this flag to ensure that the cluster is created with a profile that is optimized for cost savings. + This makes runs cheaper but start-up times may increase.""", +) +@click.option( + "--edit-before-creation", + default=False, + is_flag=True, + help="Edit the cluster specs before submitting them to the API server.", +) +@click.option( + "--wait", + "wait", + type=bool, + required=False, + default=False, + is_flag=True, + help="Enabling this flag makes the CLI wait until the cluster is running.", +) +def create_cluster( + cluster_name: str, + region: str, + role_arn: str, + external_id: str, + provider: str, + instance_types: str, + edit_before_creation: bool, + cost_savings: bool, + wait: bool, + **kwargs, +): + """Create a Lightning AI BYOC compute cluster with your cloud provider credentials.""" + if provider != "aws": + click.echo("Only AWS is supported for now.
But support for more providers is coming soon.") + return + cluster_manager = AWSClusterManager() + cluster_manager.create( + cluster_name=cluster_name, + region=region, + role_arn=role_arn, + external_id=external_id, + instance_types=instance_types.split(",") if instance_types else [], + edit_before_creation=edit_before_creation, + cost_savings=cost_savings, + wait=wait, + ) diff --git a/src/lightning_app/cli/lightning_cli_delete.py b/src/lightning_app/cli/lightning_cli_delete.py new file mode 100644 index 0000000000..c304b130bd --- /dev/null +++ b/src/lightning_app/cli/lightning_cli_delete.py @@ -0,0 +1,49 @@ +import click + +from lightning_app.cli.cmd_clusters import AWSClusterManager + + +@click.group("delete") +def delete(): + """Delete Lightning AI BYOC managed resources.""" + pass + + +@delete.command("cluster") +@click.argument("cluster", type=str) +@click.option( + "--force", + "force", + type=bool, + required=False, + default=False, + is_flag=True, + help="""Delete a BYOC cluster from Lightning AI. This does NOT delete any resources created by the cluster, + it just removes the entry from Lightning AI. + + WARNING: You should NOT use this under normal circumstances.""", +) +@click.option( + "--wait", + "wait", + type=bool, + required=False, + default=False, + is_flag=True, + help="Enabling this flag makes the CLI wait until the cluster is deleted.", +) +def delete_cluster(cluster: str, force: bool = False, wait: bool = False): + """Delete a Lightning AI BYOC compute cluster and all associated cloud provider resources. + + Deleting a cluster also deletes all Runs and Experiments that were started on the cluster. + Deletion permanently removes not only the record of all runs on a cluster, but all associated experiments, + artifacts, metrics, logs, etc. + + WARNING: This process may take a few minutes to complete, but once started it CANNOT be rolled back.
+ Deletion permanently removes not only the BYOC cluster from being managed by Lightning AI, but tears down + every BYOC resource Lightning AI managed (for that cluster id) in the host cloud. + + All object stores, container registries, logs, compute nodes, volumes, etc. are deleted and cannot be recovered. + """ + cluster_manager = AWSClusterManager() + cluster_manager.delete(cluster_id=cluster, force=force, wait=wait) diff --git a/src/lightning_app/cli/lightning_cli_list.py b/src/lightning_app/cli/lightning_cli_list.py new file mode 100644 index 0000000000..31f46537e8 --- /dev/null +++ b/src/lightning_app/cli/lightning_cli_list.py @@ -0,0 +1,16 @@ +import click + +from lightning_app.cli.cmd_clusters import AWSClusterManager + + +@click.group(name="list") +def get_list(): + """List your Lightning AI BYOC managed resources.""" + pass + + +@get_list.command("clusters") +def list_clusters(**kwargs): + """List your Lightning AI BYOC compute clusters.""" + cluster_manager = AWSClusterManager() + cluster_manager.list() diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py index cc03f5bade..10abdac4aa 100644 --- a/src/lightning_app/testing/testing.py +++ b/src/lightning_app/testing/testing.py @@ -2,6 +2,7 @@ import asyncio import json import os import shutil +import subprocess import sys import tempfile import time @@ -130,6 +131,29 @@ def browser_context_args(browser_context_args: Dict) -> Dict: } +@contextmanager +def run_cli(args) -> Generator: + """This utility is used to automate end-to-end testing of the Lightning AI CLI.""" + cmd = [ + sys.executable, + "-m", + "lightning", + ] + args + + with tempfile.TemporaryDirectory() as tmpdir: + env_copy = os.environ.copy() + process = Popen( + cmd, + cwd=tmpdir, + env=env_copy, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, stderr = process.communicate() + + yield stdout.decode("UTF-8"), stderr.decode("UTF-8") + + @requires("playwright") @contextmanager def
run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator: diff --git a/src/lightning_app/utilities/openapi.py b/src/lightning_app/utilities/openapi.py new file mode 100644 index 0000000000..f533b1f8de --- /dev/null +++ b/src/lightning_app/utilities/openapi.py @@ -0,0 +1,61 @@ +import json +from typing import Any, Dict + + +def _duplicate_checker(js): + """_duplicate_checker verifies that your JSON object doesn't contain duplicate keys.""" + result = {} + for name, value in js: + if name in result: + raise ValueError( + f"Unable to load JSON. A duplicate key {name} was detected. JSON objects must have unique keys." + ) + result[name] = value + return result + + +def string2dict(text): + """string2dict parses a JSON string into a dictionary, ensuring no keys are duplicated by accident.""" + if not isinstance(text, str): + text = text.decode("utf-8") + try: + js = json.loads(text, object_pairs_hook=_duplicate_checker) + return js + except ValueError as e: + raise ValueError(f"Unable to load JSON: {str(e)}.") from e + + +def is_openapi(obj): + """is_openapi checks if an object was generated by OpenAPI.""" + return hasattr(obj, "swagger_types") + + +def create_openapi_object(json_obj: Dict, target: Any): + """Create the OpenAPI object from the given JSON dict and based on the target object. + + Lightning AI uses the target object to make new objects from the given JSON spec so the target must be a valid + object.
+ """ + if not isinstance(json_obj, dict): + raise TypeError("json_obj must be a dictionary") + if not is_openapi(target): + raise TypeError("target must be an openapi object") + + target_attribs = {} + for key, value in json_obj.items(): + try: + # user provided key is not a valid key on openapi object + sub_target = getattr(target, key) + except AttributeError: + raise ValueError(f"Field {key} not found in the target object") + + if is_openapi(sub_target): # it's an openapi object + target_attribs[key] = create_openapi_object(value, sub_target) + else: + target_attribs[key] = value + + # TODO(sherin) - specifically process list and dict and do the validation. Also do the + # verification for enum types + + new_target = target.__class__(**target_attribs) + return new_target diff --git a/tests/tests_app/cli/test_cli.py b/tests/tests_app/cli/test_cli.py index 39d8d6b789..3e00329369 100644 --- a/tests/tests_app/cli/test_cli.py +++ b/tests/tests_app/cli/test_cli.py @@ -1,11 +1,15 @@ import os from unittest import mock +from unittest.mock import MagicMock import pytest from click.testing import CliRunner from lightning_cloud.openapi import Externalv1LightningappInstance from lightning_app.cli.lightning_cli import _main, get_app_url, login, logout, run +from lightning_app.cli.lightning_cli_create import create, create_cluster +from lightning_app.cli.lightning_cli_delete import delete, delete_cluster +from lightning_app.cli.lightning_cli_list import get_list, list_clusters from lightning_app.runners.runtime_type import RuntimeType @@ -37,7 +41,7 @@ def test_start_target_url(runtime_type, extra_args, lightning_cloud_url, expecte assert get_app_url(runtime_type, *extra_args) == expected_url -@pytest.mark.parametrize("command", [_main, run]) +@pytest.mark.parametrize("command", [_main, run, get_list, create, delete]) def test_commands(command): runner = CliRunner() result = runner.invoke(command) @@ -50,6 +54,9 @@ def test_main_lightning_cli_help(): assert "login " in res 
assert "logout " in res assert "run " in res + assert "list " in res + assert "delete " in res + assert "create " in res res = os.popen("python -m lightning run --help").read() assert "app " in res @@ -61,6 +68,55 @@ def test_main_lightning_cli_help(): assert "frontend" not in res +@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) +@mock.patch("lightning_app.cli.cmd_clusters.AWSClusterManager.create") +def test_create_cluster(create: mock.MagicMock): + runner = CliRunner() + runner.invoke( + create_cluster, + [ + "test-7", + "--provider", + "aws", + "--external-id", + "dummy", + "--role-arn", + "arn:aws:iam::1234567890:role/lai-byoc", + "--instance-types", + "t2.small", + ], + ) + + create.assert_called_once_with( + cluster_name="test-7", + region="us-east-1", + role_arn="arn:aws:iam::1234567890:role/lai-byoc", + external_id="dummy", + instance_types=["t2.small"], + edit_before_creation=False, + cost_savings=False, + wait=False, + ) + + +@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) +@mock.patch("lightning_app.cli.cmd_clusters.AWSClusterManager.list") +def test_list_clusters(list: mock.MagicMock): + runner = CliRunner() + runner.invoke(list_clusters) + + list.assert_called_once_with() + + +@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) +@mock.patch("lightning_app.cli.cmd_clusters.AWSClusterManager.delete") +def test_delete_cluster(delete: mock.MagicMock): + runner = CliRunner() + runner.invoke(delete_cluster, ["test-7"]) + + delete.assert_called_once_with(cluster_id="test-7", force=False, wait=False) + + @mock.patch("lightning_app.utilities.login.Auth._run_server") @mock.patch("lightning_app.utilities.login.Auth.clear") def test_cli_login(clear: mock.MagicMock, run_server: mock.MagicMock): diff --git a/tests/tests_app/cli/test_cmd_clusters.py b/tests/tests_app/cli/test_cmd_clusters.py new file mode 100644 index 0000000000..e835643fd9 --- /dev/null +++ b/tests/tests_app/cli/test_cmd_clusters.py @@ -0,0 
+1,135 @@ +from unittest import mock +from unittest.mock import MagicMock + +import click +import pytest +from lightning_cloud.openapi import ( + V1AWSClusterDriverSpec, + V1ClusterDriver, + V1ClusterPerformanceProfile, + V1ClusterSpec, + V1ClusterType, + V1CreateClusterRequest, + V1InstanceSpec, + V1KubernetesClusterDriver, +) +from lightning_cloud.openapi.models import Externalv1Cluster, V1ClusterState, V1ClusterStatus, V1ListClustersResponse + +from lightning_app.cli import cmd_clusters +from lightning_app.cli.cmd_clusters import AWSClusterManager + + +class FakeLightningClient: + def __init__(self, list_responses=None, consume=True): + self.list_responses = list(list_responses or []) + self.list_call_count = 0 + self.consume = consume + + def cluster_service_list_clusters(self, phase_not_in=None): + self.list_call_count = self.list_call_count + 1 + if self.consume: + return self.list_responses.pop(0) + return self.list_responses[0] + + +@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) +@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_create_cluster") +def test_create_cluster(api: mock.MagicMock): + cluster_manager = AWSClusterManager() + cluster_manager.create( + cluster_name="test-7", + external_id="dummy", + role_arn="arn:aws:iam::1234567890:role/lai-byoc", + instance_types=["t2.small"], + region="us-west-2", + ) + + api.assert_called_once_with( + body=V1CreateClusterRequest( + name="test-7", + spec=V1ClusterSpec( + cluster_type=V1ClusterType.BYOC, + performance_profile=V1ClusterPerformanceProfile.DEFAULT, + driver=V1ClusterDriver( + kubernetes=V1KubernetesClusterDriver( + aws=V1AWSClusterDriverSpec( + region="us-west-2", + role_arn="arn:aws:iam::1234567890:role/lai-byoc", + external_id="dummy", + instance_types=[V1InstanceSpec(name="t2.small")], + ) + ) + ), + ), + ) + ) + + +@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
+@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_list_clusters") +def test_list_clusters(api: mock.MagicMock): + cluster_manager = AWSClusterManager() + cluster_manager.list() + + api.assert_called_once_with(phase_not_in=[V1ClusterState.DELETED]) + + +@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) +@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_delete_cluster") +def test_delete_cluster(api: mock.MagicMock): + cluster_manager = AWSClusterManager() + cluster_manager.delete(cluster_id="test-7") + + api.assert_called_once_with(id="test-7", force=False) + + +class Test_check_cluster_name_is_valid: + @pytest.mark.parametrize("name", ["test-7", "0wildgoat"]) + def test_valid(self, name): + assert cmd_clusters._check_cluster_name_is_valid(None, None, name) + + @pytest.mark.parametrize( + "name", ["(&%)!@#", "1234567890123456789012345678901234567890123456789012345678901234567890"] + ) + def test_invalid(self, name): + with pytest.raises(click.ClickException) as e: + cmd_clusters._check_cluster_name_is_valid(None, None, name) + assert "cluster name doesn't match regex pattern" in str(e.value) + + +class Test_wait_for_cluster_state: + # TODO(rra) add tests for pagination + + @pytest.mark.parametrize("target_state", [V1ClusterState.RUNNING, V1ClusterState.DELETED]) + @pytest.mark.parametrize( + "previous_state", [V1ClusterState.QUEUED, V1ClusterState.PENDING, V1ClusterState.UNSPECIFIED] + ) + def test_happy_path(self, target_state, previous_state): + client = FakeLightningClient( + list_responses=[ + V1ListClustersResponse( + clusters=[Externalv1Cluster(id="test-cluster", status=V1ClusterStatus(phase=state))] + ) + for state in [previous_state, target_state] + ] + ) + cmd_clusters._wait_for_cluster_state(client, "test-cluster", target_state, check_timeout=0.1) + assert client.list_call_count == 2 + + @pytest.mark.parametrize("target_state", [V1ClusterState.RUNNING, V1ClusterState.DELETED]) + def
test_times_out(self, target_state): + client = FakeLightningClient( + list_responses=[ + V1ListClustersResponse( + clusters=[ + Externalv1Cluster(id="test-cluster", status=V1ClusterStatus(phase=V1ClusterState.UNSPECIFIED)) + ] + ) + ], + consume=False, + ) + with pytest.raises(click.ClickException) as e: + cmd_clusters._wait_for_cluster_state( + client, "test-cluster", target_state, max_wait_time=0.4, check_timeout=0.2 + ) + assert "Max wait time elapsed" in str(e.value) diff --git a/tests/tests_clusters/__init__.py b/tests/tests_clusters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tests_clusters/test_cluster_lifecycle.py b/tests/tests_clusters/test_cluster_lifecycle.py new file mode 100644 index 0000000000..cd48761f5f --- /dev/null +++ b/tests/tests_clusters/test_cluster_lifecycle.py @@ -0,0 +1,53 @@ +import os +import uuid + +import pytest + +from lightning_app.testing.testing import run_cli + + +@pytest.mark.cloud +@pytest.mark.skipif( + os.environ.get("LIGHTNING_BYOC_ROLE_ARN") is None, reason="missing LIGHTNING_BYOC_ROLE_ARN environment variable" +) +@pytest.mark.skipif( + os.environ.get("LIGHTNING_BYOC_EXTERNAL_ID") is None, + reason="missing LIGHTNING_BYOC_EXTERNAL_ID environment variable", +) +def test_cluster_lifecycle() -> None: + role_arn = os.environ.get("LIGHTNING_BYOC_ROLE_ARN", None) + external_id = os.environ.get("LIGHTNING_BYOC_EXTERNAL_ID", None) + region = "us-west-2" + instance_types = "t2.small,t3.small" + cluster_name = f"byoc-{uuid.uuid4()}" + with run_cli( + [ + "create", + "cluster", + cluster_name, + "--provider", + "aws", + "--role-arn", + role_arn, + "--external-id", + external_id, + "--region", + region, + "--instance-types", + instance_types, + "--wait", + ] + ) as (stdout, stderr): + assert "success" in stdout, f"stdout: {stdout}\nstderr: {stderr}" + + with run_cli(["list", "clusters"]) as (stdout, stderr): + assert cluster_name in stdout, f"stdout: {stdout}\nstderr: {stderr}" + + with
run_cli(["delete", "cluster", "--force", cluster_name]) as (stdout, stderr): + assert "success" in stdout, f"stdout: {stdout}\nstderr: {stderr}" + + +@pytest.mark.cloud +def test_cluster_list() -> None: + with run_cli(["list", "clusters"]) as (stdout, stderr): + assert "lightning-cloud" in stdout, f"stdout: {stdout}\nstderr: {stderr}"