[CLI] add support for cluster management (#13835)
This commit is contained in:
parent
b3203d93d0
commit
2919dcf7ee
|
@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
### Added
|
||||
|
||||
- Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602))
|
||||
- Add support for Lightning AI BYOC cluster management ([#13835](https://github.com/Lightning-AI/lightning/pull/13835))
|
||||
|
||||
- Adds `LightningTrainingComponent`. `LightningTrainingComponent` orchestrates multi-node training in the cloud ([#13830](https://github.com/Lightning-AI/lightning/pull/13830))
|
||||
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
import json
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import click
|
||||
from lightning_cloud.openapi import (
|
||||
V1AWSClusterDriverSpec,
|
||||
V1ClusterDriver,
|
||||
V1ClusterPerformanceProfile,
|
||||
V1ClusterSpec,
|
||||
V1CreateClusterRequest,
|
||||
V1InstanceSpec,
|
||||
V1KubernetesClusterDriver,
|
||||
)
|
||||
from lightning_cloud.openapi.models import Externalv1Cluster, V1ClusterState, V1ClusterType
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from lightning_app.cli.core import Formatable
|
||||
from lightning_app.utilities.network import LightningClient
|
||||
from lightning_app.utilities.openapi import create_openapi_object, string2dict
|
||||
|
||||
CLUSTER_STATE_CHECKING_TIMEOUT = 60
|
||||
MAX_CLUSTER_WAIT_TIME = 5400
|
||||
|
||||
|
||||
class AWSClusterManager:
    """AWSClusterManager implements API calls specific to Lightning AI BYOC compute clusters when the AWS provider
    is selected as the backend compute."""

    def __init__(self):
        self.api_client = LightningClient()

    def create(
        self,
        cost_savings: bool = False,
        cluster_name: str = None,
        role_arn: str = None,
        region: str = "us-east-1",
        external_id: str = None,
        instance_types: [str] = None,
        edit_before_creation: bool = False,
        wait: bool = False,
    ):
        """Request Lightning AI BYOC compute cluster creation.

        Args:
            cost_savings: Specifies if the cluster uses cost savings mode
            cluster_name: The name of the cluster to be created
            role_arn: AWS IAM Role ARN used to provision resources
            region: AWS region containing compute resources
            external_id: AWS IAM Role external ID
            instance_types: AWS instance types supported by the cluster. ``None`` (the default) is treated as
                an empty list.
            edit_before_creation: Enables interactive editing of requests before submitting it to Lightning AI.
            wait: Waits for the cluster to be in a RUNNING state. Only use this for debugging.
        """
        # In cost saving mode the number of compute nodes is reduced to one, reducing the cost for clusters
        # with low utilization.
        performance_profile = (
            V1ClusterPerformanceProfile.COST_SAVING if cost_savings else V1ClusterPerformanceProfile.DEFAULT
        )

        body = V1CreateClusterRequest(
            name=cluster_name,
            spec=V1ClusterSpec(
                cluster_type=V1ClusterType.BYOC,
                performance_profile=performance_profile,
                driver=V1ClusterDriver(
                    kubernetes=V1KubernetesClusterDriver(
                        aws=V1AWSClusterDriverSpec(
                            region=region,
                            role_arn=role_arn,
                            external_id=external_id,
                            # `None` replaces the former shared mutable `[]` default.
                            instance_types=[V1InstanceSpec(name=x) for x in (instance_types or [])],
                        )
                    )
                ),
            ),
        )
        new_body = body
        if edit_before_creation:
            # `click.edit` returns None when the editor is closed without saving.
            after = click.edit(json.dumps(body.to_dict(), indent=4))
            if after is not None:
                new_body = create_openapi_object(string2dict(after), body)
            if new_body == body:
                click.echo("cluster unchanged")

        resp = self.api_client.cluster_service_create_cluster(body=new_body)
        if wait:
            _wait_for_cluster_state(self.api_client, resp.id, V1ClusterState.RUNNING)

        # NOTE: `resp` is the creation response, so the phase printed here is the one reported at creation time.
        click.echo(f"{resp.id} cluster is {resp.status.phase}")

    def list(self):
        """Print a table of all clusters that are not in the DELETED phase."""
        resp = self.api_client.cluster_service_list_clusters(phase_not_in=[V1ClusterState.DELETED])
        console = Console()
        console.print(ClusterList(resp.clusters).as_table())

    def delete(self, cluster_id: str = None, force: bool = False, wait: bool = False):
        """Request deletion of a cluster.

        Args:
            cluster_id: Identifier of the cluster to delete
            force: Forwarded to the API as ``force``; the user is warned and asked to confirm first
            wait: Waits for the cluster to reach the DELETED state
        """
        if force:
            # Built from separate literals so source indentation does not leak into the user-facing text.
            click.echo(
                "\n"
                "Deletes a BYOC cluster. Lightning AI removes cluster artifacts and any resources running on the "
                "cluster.\n"
                "WARNING: Deleting a cluster does not clean up any resources managed by Lightning AI.\n"
                "Check your cloud provider to verify that existing cloud resources are deleted.\n"
            )
            click.confirm("Do you want to continue?", abort=True)

        self.api_client.cluster_service_delete_cluster(id=cluster_id, force=force)
        click.echo("Cluster deletion triggered successfully")

        if wait:
            _wait_for_cluster_state(self.api_client, cluster_id, V1ClusterState.DELETED)
|
||||
|
||||
|
||||
class ClusterList(Formatable):
    """Renders a list of clusters either as a rich table or as a JSON string."""

    def __init__(self, clusters: [Externalv1Cluster]):
        self.clusters = clusters

    def as_json(self) -> str:
        """Return the clusters serialized as a JSON array."""
        # The openapi models are not JSON serializable themselves, so serialize their dict form.
        # `default=str` is deliberate: it covers non-JSON values such as the `created_at` timestamp.
        return json.dumps([cluster.to_dict() for cluster in self.clusters], default=str)

    def as_table(self) -> Table:
        """Return the clusters as a table with id, name, type, status and creation date columns."""
        table = Table("id", "name", "type", "status", "created", show_header=True, header_style="bold green")
        phases = {
            V1ClusterState.QUEUED: Text("queued", style="bold yellow"),
            V1ClusterState.PENDING: Text("pending", style="bold yellow"),
            V1ClusterState.RUNNING: Text("running", style="bold green"),
            V1ClusterState.FAILED: Text("failed", style="bold red"),
            V1ClusterState.DELETED: Text("deleted", style="bold red"),
        }

        cluster_type_lookup = {
            V1ClusterType.BYOC: Text("byoc", style="bold yellow"),
            V1ClusterType.GLOBAL: Text("lightning-cloud", style="bold green"),
        }
        for cluster in self.clusters:
            cluster: Externalv1Cluster
            # Fall back instead of raising KeyError for phases missing from the table (e.g. UNSPECIFIED),
            # mirroring the cluster type lookup below.
            status = phases.get(cluster.status.phase, Text("unknown", style="red"))
            if cluster.spec.desired_state == V1ClusterState.DELETED and cluster.status.phase != V1ClusterState.DELETED:
                status = Text("terminating", style="bold red")

            # this guard is necessary only until 0.3.93 releases which includes the `created_at`
            # field to the external API
            created_at = datetime.now()
            if hasattr(cluster, "created_at"):
                created_at = cluster.created_at

            table.add_row(
                cluster.id,
                cluster.name,
                cluster_type_lookup.get(cluster.spec.cluster_type, Text("unknown", style="red")),
                status,
                created_at.strftime("%Y-%m-%d") if created_at else "",
            )
        return table
|
||||
|
||||
|
||||
def _wait_for_cluster_state(
    api_client: LightningClient,
    cluster_id: str,
    target_state: V1ClusterState,
    max_wait_time: int = MAX_CLUSTER_WAIT_TIME,
    check_timeout: int = CLUSTER_STATE_CHECKING_TIMEOUT,
):
    """Poll the cluster listing until the given cluster reaches ``target_state``.

    Args:
        api_client: LightningClient used for polling
        cluster_id: Identifier of the cluster to watch
        target_state: Phase the cluster has to report for the wait to finish
        max_wait_time: Upper bound on the total waiting time (in seconds)
        check_timeout: Pause between two consecutive polls (in seconds)

    Raises:
        click.ClickException: If the cluster reports the FAILED phase, or if ``max_wait_time`` elapses first.
    """
    began = time.time()
    waited = 0
    while waited < max_wait_time:
        listing = api_client.cluster_service_list_clusters()
        # First cluster with a matching id, or None when it is not part of the listing.
        found = next((candidate for candidate in listing.clusters if candidate.id == cluster_id), None)
        if found is not None:
            if found.status.phase == target_state:
                return
            if found.status.phase == V1ClusterState.FAILED:
                raise click.ClickException(f"Cluster {cluster_id} is in failed state.")
        time.sleep(check_timeout)
        waited = time.time() - began

    # Only reached when the loop condition turned false, i.e. the deadline passed.
    raise click.ClickException("Max wait time elapsed")
|
||||
|
||||
|
||||
def _check_cluster_name_is_valid(_ctx, _param, value):
|
||||
pattern = r"^(?!-)[a-z0-9-]{1,63}(?<!-)$"
|
||||
if not re.match(pattern, value):
|
||||
raise click.ClickException(
|
||||
"""The cluster name is invalid.
|
||||
Cluster names can only contain lowercase letters, numbers, and periodic hyphens ( - ).
|
||||
Provide a cluster name using valid characters and try again."""
|
||||
)
|
||||
return value
|
|
@ -0,0 +1,13 @@
|
|||
import abc
|
||||
|
||||
from rich.table import Table
|
||||
|
||||
|
||||
class Formatable(abc.ABC):
    """Interface for objects that can be rendered both as a rich table and as a JSON string."""

    @abc.abstractmethod
    def as_table(self) -> Table:
        """Return the object rendered as a ``rich`` table."""

    @abc.abstractmethod
    def as_json(self) -> str:
        """Return the object rendered as a JSON string."""
|
|
@ -12,6 +12,9 @@ from requests.exceptions import ConnectionError
|
|||
|
||||
from lightning_app import __version__ as ver
|
||||
from lightning_app.cli import cmd_init, cmd_install, cmd_pl_init, cmd_react_ui_init
|
||||
from lightning_app.cli.lightning_cli_create import create
|
||||
from lightning_app.cli.lightning_cli_delete import delete
|
||||
from lightning_app.cli.lightning_cli_list import get_list
|
||||
from lightning_app.core.constants import get_lightning_cloud_url, LOCAL_LAUNCH_ADMIN_VIEW
|
||||
from lightning_app.runners.runtime import dispatch
|
||||
from lightning_app.runners.runtime_type import RuntimeType
|
||||
|
@ -206,16 +209,9 @@ def stop():
|
|||
pass
|
||||
|
||||
|
||||
@_main.group(hidden=True)
|
||||
def delete():
|
||||
"""Delete an application."""
|
||||
pass
|
||||
|
||||
|
||||
@_main.group(name="list", hidden=True)
|
||||
def get_list():
|
||||
"""List your applications."""
|
||||
pass
|
||||
_main.add_command(get_list)
|
||||
_main.add_command(delete)
|
||||
_main.add_command(create)
|
||||
|
||||
|
||||
@_main.group()
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
import click
|
||||
|
||||
from lightning_app.cli.cmd_clusters import _check_cluster_name_is_valid, AWSClusterManager
|
||||
|
||||
|
||||
# Command group collecting every `create <resource>` sub-command.
@click.group(name="create")
def create():
    """Create Lightning AI BYOC managed resources."""
|
||||
|
||||
|
||||
@create.command("cluster")
@click.argument("cluster_name", callback=_check_cluster_name_is_valid)
@click.option("--provider", "provider", type=str, default="aws", help="cloud provider to be used for your cluster")
@click.option("--external-id", "external_id", type=str, required=True, help="AWS IAM Role external ID.")
@click.option(
    "--role-arn", "role_arn", type=str, required=True, help="AWS role ARN attached to the associated resources."
)
@click.option(
    "--region",
    "region",
    type=str,
    required=False,
    default="us-east-1",
    help="AWS region that is used to host the associated resources.",
)
@click.option(
    "--instance-types",
    "instance_types",
    type=str,
    required=False,
    default=None,
    help="Instance types that you want to support, for compute jobs within the cluster.",
)
@click.option(
    "--cost-savings",
    "cost_savings",
    type=bool,
    required=False,
    default=False,
    is_flag=True,
    help="Use this flag to ensure that the cluster is created with a profile that is optimized for cost savings. "
    "This makes runs cheaper but start-up times may increase.",
)
@click.option(
    "--edit-before-creation",
    default=False,
    is_flag=True,
    help="Edit the cluster specs before submitting them to the API server.",
)
@click.option(
    "--wait",
    "wait",
    type=bool,
    required=False,
    default=False,
    is_flag=True,
    help="Enabling this flag makes the CLI wait until the cluster is running.",
)
def create_cluster(
    cluster_name: str,
    region: str,
    role_arn: str,
    external_id: str,
    provider: str,
    instance_types: str,
    edit_before_creation: bool,
    cost_savings: bool,
    wait: bool,
    **kwargs,
):
    """Create a Lightning AI BYOC compute cluster with your cloud provider credentials."""
    if provider != "aws":
        click.echo("Only AWS is supported for now. But support for more providers is coming soon.")
        return

    # `--instance-types` is optional and defaults to None, so only split when a value was given.
    # Surrounding whitespace and empty entries (e.g. from "a, b," ) are dropped.
    parsed_instance_types = []
    if instance_types is not None:
        parsed_instance_types = [name.strip() for name in instance_types.split(",") if name.strip()]

    cluster_manager = AWSClusterManager()
    cluster_manager.create(
        cluster_name=cluster_name,
        region=region,
        role_arn=role_arn,
        external_id=external_id,
        instance_types=parsed_instance_types,
        edit_before_creation=edit_before_creation,
        cost_savings=cost_savings,
        wait=wait,
    )
|
|
@ -0,0 +1,49 @@
|
|||
import click
|
||||
|
||||
from lightning_app.cli.cmd_clusters import AWSClusterManager
|
||||
|
||||
|
||||
# Command group collecting every `delete <resource>` sub-command.
@click.group(name="delete")
def delete():
    """Delete Lightning AI BYOC managed resources."""
|
||||
|
||||
|
||||
@delete.command("cluster")
@click.argument("cluster", type=str)
@click.option(
    "--force",
    "force",
    type=bool,
    required=False,
    default=False,
    is_flag=True,
    help="Delete a BYOC cluster from Lightning AI. This does NOT delete any resources created by the cluster, "
    "it just removes the entry from Lightning AI."
    "\n\n"
    "WARNING: You should NOT use this under normal circumstances.",
)
@click.option(
    "--wait",
    "wait",
    type=bool,
    required=False,
    default=False,
    is_flag=True,
    help="Enabling this flag makes the CLI wait until the cluster is deleted.",
)
def delete_cluster(cluster: str, force: bool = False, wait: bool = False):
    """Delete a Lightning AI BYOC compute cluster and all associated cloud provider resources.

    Deleting a cluster also deletes all Runs and Experiments that were started on the cluster.
    Deletion permanently removes not only the record of all runs on a cluster, but all associated experiments,
    artifacts, metrics, logs, etc.

    WARNING: This process may take a few minutes to complete, but once started it CANNOT be rolled back.
    Deletion permanently removes not only the BYOC cluster from being managed by Lightning AI, but tears down
    every BYOC resource Lightning AI managed (for that cluster id) in the host cloud.

    All object stores, container registries, logs, compute nodes, volumes, etc. are deleted and cannot be recovered.
    """
    # The positional CLI argument is forwarded as the cluster id.
    cluster_manager = AWSClusterManager()
    cluster_manager.delete(cluster_id=cluster, force=force, wait=wait)
|
|
@ -0,0 +1,16 @@
|
|||
import click
|
||||
|
||||
from lightning_app.cli.cmd_clusters import AWSClusterManager
|
||||
|
||||
|
||||
# Command group collecting every `list <resource>` sub-command; exposed on the CLI as `list`.
@click.group("list")
def get_list():
    """List your Lightning AI BYOC managed resources."""
|
||||
|
||||
|
||||
@get_list.command("clusters")
def list_clusters(**kwargs):
    """List your Lightning AI BYOC compute clusters."""
    # The manager prints the cluster table itself; nothing is returned.
    AWSClusterManager().list()
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
@ -130,6 +131,29 @@ def browser_context_args(browser_context_args: Dict) -> Dict:
|
|||
}
|
||||
|
||||
|
||||
@contextmanager
def run_cli(args) -> Generator:
    """This utility is used to automate end-to-end testing of the Lightning AI CLI.

    Runs ``python -m lightning <args>`` in a temporary working directory with a copy of the current
    environment and waits for it to finish.

    Args:
        args: Command line arguments appended after ``lightning``.

    Yields:
        A ``(stdout, stderr)`` tuple with the decoded output of the finished process.
    """
    cmd = [sys.executable, "-m", "lightning", *args]

    with tempfile.TemporaryDirectory() as tmpdir:
        with Popen(
            cmd,
            cwd=tmpdir,
            env=os.environ.copy(),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ) as process:
            # `communicate` drains both pipes while waiting. A bare `wait()` with PIPE can deadlock
            # once the child fills an OS pipe buffer.
            stdout, stderr = process.communicate()

        yield stdout.decode("UTF-8"), stderr.decode("UTF-8")
|
||||
|
||||
|
||||
@requires("playwright")
|
||||
@contextmanager
|
||||
def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator:
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
import json
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def _duplicate_checker(js):
|
||||
"""_duplicate_checker verifies that your JSON object doesn't contain duplicate keys."""
|
||||
result = {}
|
||||
for name, value in js:
|
||||
if name in result:
|
||||
raise ValueError(
|
||||
f"Unable to load JSON. A duplicate key {name} was detected. JSON objects must have unique keys."
|
||||
)
|
||||
result[name] = value
|
||||
return result
|
||||
|
||||
|
||||
def string2dict(text):
    """string2dict parses a JSON string into a dictionary, ensuring no keys are duplicated by accident.

    Args:
        text: The JSON document as ``str``; any other input is decoded as UTF-8 first.

    Returns:
        The parsed JSON value.

    Raises:
        ValueError: If the text is not valid JSON or an object contains a duplicate key.
    """
    if not isinstance(text, str):
        text = text.decode("utf-8")
    try:
        return json.loads(text, object_pairs_hook=_duplicate_checker)
    except ValueError as e:
        # Chain explicitly so the original parsing error stays attached as the cause.
        raise ValueError(f"Unable to load JSON: {str(e)}.") from e
|
||||
|
||||
|
||||
def is_openapi(obj):
    """is_openapi reports whether an object exposes the ``swagger_types`` attribute of OpenAPI generated models."""
    try:
        obj.swagger_types
    except AttributeError:
        return False
    return True
|
||||
|
||||
|
||||
def create_openapi_object(json_obj: Dict, target: Any):
    """Build a new object of ``target``'s class from a JSON dictionary.

    ``target`` serves as the template: every key of ``json_obj`` must be an attribute of it, and attributes
    that are themselves OpenAPI objects are rebuilt recursively from the matching nested dictionary.

    Raises:
        TypeError: If ``json_obj`` is not a dictionary or ``target`` is not an OpenAPI object.
        ValueError: If a key of ``json_obj`` is not an attribute of ``target``.
    """
    if not isinstance(json_obj, dict):
        raise TypeError("json_obj must be a dictionary")
    if not is_openapi(target):
        raise TypeError("target must be an openapi object")

    constructor_kwargs = {}
    for key, payload in json_obj.items():
        try:
            template = getattr(target, key)
        except AttributeError:
            # the key supplied by the user does not exist on the openapi object
            raise ValueError(f"Field {key} not found in the target object")

        # Nested openapi objects are rebuilt recursively; any other value is taken over as-is.
        constructor_kwargs[key] = create_openapi_object(payload, template) if is_openapi(template) else payload

    # TODO(sherin) - specifically process list and dict and do the validation. Also do the
    #  verification for enum types

    return target.__class__(**constructor_kwargs)
|
|
@ -1,11 +1,15 @@
|
|||
import os
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from lightning_cloud.openapi import Externalv1LightningappInstance
|
||||
|
||||
from lightning_app.cli.lightning_cli import _main, get_app_url, login, logout, run
|
||||
from lightning_app.cli.lightning_cli_create import create, create_cluster
|
||||
from lightning_app.cli.lightning_cli_delete import delete, delete_cluster
|
||||
from lightning_app.cli.lightning_cli_list import get_list, list_clusters
|
||||
from lightning_app.runners.runtime_type import RuntimeType
|
||||
|
||||
|
||||
|
@ -37,7 +41,7 @@ def test_start_target_url(runtime_type, extra_args, lightning_cloud_url, expecte
|
|||
assert get_app_url(runtime_type, *extra_args) == expected_url
|
||||
|
||||
|
||||
@pytest.mark.parametrize("command", [_main, run])
|
||||
@pytest.mark.parametrize("command", [_main, run, get_list, create, delete])
|
||||
def test_commands(command):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(command)
|
||||
|
@ -50,6 +54,9 @@ def test_main_lightning_cli_help():
|
|||
assert "login " in res
|
||||
assert "logout " in res
|
||||
assert "run " in res
|
||||
assert "list " in res
|
||||
assert "delete " in res
|
||||
assert "create " in res
|
||||
|
||||
res = os.popen("python -m lightning run --help").read()
|
||||
assert "app " in res
|
||||
|
@ -61,6 +68,55 @@ def test_main_lightning_cli_help():
|
|||
assert "frontend" not in res
|
||||
|
||||
|
||||
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.cli.cmd_clusters.AWSClusterManager.create")
def test_create_cluster(create: mock.MagicMock):
    """The `create cluster` command forwards its parsed options to `AWSClusterManager.create`."""
    cli_args = [
        "test-7",
        "--provider",
        "aws",
        "--external-id",
        "dummy",
        "--role-arn",
        "arn:aws:iam::1234567890:role/lai-byoc",
        "--instance-types",
        "t2.small",
    ]
    # Options not given on the command line are expected at their defaults.
    expected_kwargs = dict(
        cluster_name="test-7",
        region="us-east-1",
        role_arn="arn:aws:iam::1234567890:role/lai-byoc",
        external_id="dummy",
        instance_types=["t2.small"],
        edit_before_creation=False,
        cost_savings=False,
        wait=False,
    )

    CliRunner().invoke(create_cluster, cli_args)

    create.assert_called_once_with(**expected_kwargs)
|
||||
|
||||
|
||||
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.cli.cmd_clusters.AWSClusterManager.list")
def test_list_clusters(list: mock.MagicMock):
    """The `list clusters` command calls `AWSClusterManager.list` without arguments."""
    CliRunner().invoke(list_clusters)

    list.assert_called_once_with()
|
||||
|
||||
|
||||
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.cli.cmd_clusters.AWSClusterManager.delete")
def test_delete_cluster(delete: mock.MagicMock):
    """The `delete cluster` command forwards the cluster argument and default flags to the manager."""
    cli_args = ["test-7"]
    CliRunner().invoke(delete_cluster, cli_args)

    delete.assert_called_once_with(cluster_id="test-7", force=False, wait=False)
|
||||
|
||||
|
||||
@mock.patch("lightning_app.utilities.login.Auth._run_server")
|
||||
@mock.patch("lightning_app.utilities.login.Auth.clear")
|
||||
def test_cli_login(clear: mock.MagicMock, run_server: mock.MagicMock):
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import click
|
||||
import pytest
|
||||
from lightning_cloud.openapi import (
|
||||
V1AWSClusterDriverSpec,
|
||||
V1ClusterDriver,
|
||||
V1ClusterPerformanceProfile,
|
||||
V1ClusterSpec,
|
||||
V1ClusterType,
|
||||
V1CreateClusterRequest,
|
||||
V1InstanceSpec,
|
||||
V1KubernetesClusterDriver,
|
||||
)
|
||||
from lightning_cloud.openapi.models import Externalv1Cluster, V1ClusterState, V1ClusterStatus, V1ListClustersResponse
|
||||
|
||||
from lightning_app.cli import cmd_clusters
|
||||
from lightning_app.cli.cmd_clusters import AWSClusterManager
|
||||
|
||||
|
||||
class FakeLightningClient:
    """Test double for the cluster listing call of the Lightning client.

    Args:
        list_responses: Canned responses for ``cluster_service_list_clusters``. Defaults to a fresh empty list.
        consume: When true each call pops the LAST remaining response; otherwise the first response is
            returned on every call.
    """

    def __init__(self, list_responses=None, consume=True):
        # A `None` default avoids sharing one mutable list between instances; a list passed in by the
        # caller is still used (and consumed) as-is.
        self.list_responses = [] if list_responses is None else list_responses
        self.list_call_count = 0
        self.consume = consume

    def cluster_service_list_clusters(self, phase_not_in=None):
        """Return the next canned response and count the call; ``phase_not_in`` is ignored."""
        self.list_call_count = self.list_call_count + 1
        if self.consume:
            return self.list_responses.pop()
        return self.list_responses[0]
|
||||
|
||||
|
||||
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_create_cluster")
def test_create_cluster(api: mock.MagicMock):
    """`AWSClusterManager.create` submits a BYOC request built from its arguments."""
    AWSClusterManager().create(
        cluster_name="test-7",
        external_id="dummy",
        role_arn="arn:aws:iam::1234567890:role/lai-byoc",
        instance_types=["t2.small"],
        region="us-west-2",
    )

    # Expected request, assembled inside-out.
    expected_aws_spec = V1AWSClusterDriverSpec(
        region="us-west-2",
        role_arn="arn:aws:iam::1234567890:role/lai-byoc",
        external_id="dummy",
        instance_types=[V1InstanceSpec(name="t2.small")],
    )
    expected_spec = V1ClusterSpec(
        cluster_type=V1ClusterType.BYOC,
        performance_profile=V1ClusterPerformanceProfile.DEFAULT,
        driver=V1ClusterDriver(kubernetes=V1KubernetesClusterDriver(aws=expected_aws_spec)),
    )
    expected_request = V1CreateClusterRequest(name="test-7", spec=expected_spec)

    api.assert_called_once_with(body=expected_request)
|
||||
|
||||
|
||||
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_list_clusters")
def test_list_clusters(api: mock.MagicMock):
    """`AWSClusterManager.list` asks the API to leave out clusters in the DELETED phase."""
    AWSClusterManager().list()

    excluded_phases = [V1ClusterState.DELETED]
    api.assert_called_once_with(phase_not_in=excluded_phases)
|
||||
|
||||
|
||||
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_delete_cluster")
def test_delete_cluster(api: mock.MagicMock):
    """`AWSClusterManager.delete` forwards the id and defaults `force` to False."""
    AWSClusterManager().delete(cluster_id="test-7")

    api.assert_called_once_with(id="test-7", force=False)
|
||||
|
||||
|
||||
class Test_check_cluster_name_is_valid:
    """Tests for the click callback that validates cluster names."""

    @pytest.mark.parametrize("name", ["test-7", "0wildgoat"])
    def test_valid(self, name):
        assert cmd_clusters._check_cluster_name_is_valid(None, None, name)

    @pytest.mark.parametrize(
        "name", ["(&%)!@#", "1234567890123456789012345678901234567890123456789012345678901234567890"]
    )
    def test_invalid(self, name):
        with pytest.raises(click.ClickException) as e:
            cmd_clusters._check_cluster_name_is_valid(None, None, name)
        # Checked after the `with` block so the assertion actually runs, and against the text the
        # callback really raises.
        assert "The cluster name is invalid" in str(e.value)
|
||||
|
||||
|
||||
class Test_wait_for_cluster_state:
    """Tests for the polling helper `_wait_for_cluster_state`."""

    # TODO(rra) add tests for pagination

    @pytest.mark.parametrize("target_state", [V1ClusterState.RUNNING, V1ClusterState.DELETED])
    @pytest.mark.parametrize(
        "previous_state", [V1ClusterState.QUEUED, V1ClusterState.PENDING, V1ClusterState.UNSPECIFIED]
    )
    def test_happy_path(self, target_state, previous_state):
        # The fake client pops responses from the END of the list, so they are listed in reverse
        # order of delivery: `previous_state` is served first, `target_state` second.
        client = FakeLightningClient(
            list_responses=[
                V1ListClustersResponse(
                    clusters=[Externalv1Cluster(id="test-cluster", status=V1ClusterStatus(phase=state))]
                )
                for state in [target_state, previous_state]
            ]
        )
        cmd_clusters._wait_for_cluster_state(client, "test-cluster", target_state, check_timeout=0.1)
        # One poll sees the previous state, the second one sees the target state.
        assert client.list_call_count == 2

    @pytest.mark.parametrize("target_state", [V1ClusterState.RUNNING, V1ClusterState.DELETED])
    def test_times_out(self, target_state):
        # `consume=False` makes the client report the same UNSPECIFIED phase on every poll.
        client = FakeLightningClient(
            list_responses=[
                V1ListClustersResponse(
                    clusters=[
                        Externalv1Cluster(id="test-cluster", status=V1ClusterStatus(phase=V1ClusterState.UNSPECIFIED))
                    ]
                )
            ],
            consume=False,
        )
        with pytest.raises(click.ClickException) as e:
            cmd_clusters._wait_for_cluster_state(
                client, "test-cluster", target_state, max_wait_time=0.4, check_timeout=0.2
            )
        assert "Max wait time elapsed" in str(e.value)
|
|
@ -0,0 +1,53 @@
|
|||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from src.lightning_app.testing.testing import run_cli
|
||||
|
||||
|
||||
@pytest.mark.cloud
@pytest.mark.skipif(
    os.environ.get("LIGHTNING_BYOC_ROLE_ARN") is None, reason="missing LIGHTNING_BYOC_ROLE_ARN environment variable"
)
@pytest.mark.skipif(
    os.environ.get("LIGHTNING_BYOC_EXTERNAL_ID") is None,
    reason="missing LIGHTNING_BYOC_EXTERNAL_ID environment variable",
)
def test_cluster_lifecycle() -> None:
    """Create, list and delete a BYOC cluster through the CLI."""
    role_arn = os.environ.get("LIGHTNING_BYOC_ROLE_ARN", None)
    external_id = os.environ.get("LIGHTNING_BYOC_EXTERNAL_ID", None)
    region = "us-west-2"
    instance_types = "t2.small,t3.small"
    cluster_name = f"byoc-{uuid.uuid4()}"
    with run_cli(
        [
            "create",
            "cluster",
            cluster_name,
            "--provider",
            "aws",
            "--role-arn",
            role_arn,
            "--external-id",
            external_id,
            "--region",
            region,
            "--instance-types",
            instance_types,
            "--wait",
        ]
    ) as (stdout, stderr):
        # `create cluster` reports "<id> cluster is <phase>" and never prints the word "success".
        assert "cluster is" in stdout, f"stdout: {stdout}\nstderr: {stderr}"

    with run_cli(["list", "clusters"]) as (stdout, stderr):
        assert cluster_name in stdout, f"stdout: {stdout}\nstderr: {stderr}"

    # NOTE(review): `--force` makes the CLI ask for interactive confirmation before deleting; `run_cli`
    # provides no stdin, so this presumably aborts in non-interactive runs — confirm.
    with run_cli(["delete", "cluster", "--force", cluster_name]) as (stdout, stderr):
        assert "success" in stdout, f"stdout: {stdout}\nstderr: {stderr}"
|
||||
|
||||
|
||||
@pytest.mark.cloud
def test_cluster_list() -> None:
    """The cluster listing printed by the CLI mentions the `lightning-cloud` cluster type."""
    with run_cli(["list", "clusters"]) as outputs:
        stdout, stderr = outputs
        assert "lightning-cloud" in stdout, f"stdout: {stdout}\nstderr: {stderr}"
|
Loading…
Reference in New Issue