lightning/tests/tests_app/cli/test_cmd_launch.py

329 lines
9.8 KiB
Python

import os
import signal
import time
from functools import partial
from multiprocessing import Process
from pathlib import Path
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock
from click.testing import CliRunner
from lightning.app.cli.lightning_cli_launch import run_flow, run_flow_and_servers, run_frontend, run_server
from lightning.app.core.queues import QueuingSystem
from lightning.app.frontend.web import StaticWebFrontend
from lightning.app.launcher import launcher
from lightning.app.runners.runtime import load_app_from_file
from lightning.app.testing.helpers import EmptyWork, _RunIf
from lightning.app.utilities.app_commands import run_app_commands
from lightning.app.utilities.network import find_free_network_port
from tests_app import _PROJECT_ROOT
# Absolute path to the sample app script used by test_run_server, test_run_work and test_run_flow.
_FILE_PATH = os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py")
def test_run_frontend(monkeypatch):
    """Start the frontend server of one LightningFlow through the CLI, as the cloud dispatcher does.

    Lightning AI issues this CLI call itself; it is not meant to be invoked by users directly.
    ``StaticWebFrontend.start_server`` is replaced with a mock, so no server is actually started.
    """
    start_server_mock = Mock()
    monkeypatch.setattr(StaticWebFrontend, "start_server", start_server_mock)

    free_port = find_free_network_port()
    app_file = str(Path(__file__).parent / "launch_data" / "app_v0" / "app.py")
    cli_args = [
        app_file,
        "--flow-name",
        "root.aas",
        "--host",
        "localhost",
        "--port",
        free_port,
    ]

    result = CliRunner().invoke(run_frontend, cli_args)

    assert result.exit_code == 0
    start_server_mock.assert_called_once()
    start_server_mock.assert_called_with("localhost", free_port)
class MockRedisQueue:
    """In-memory stand-in for ``RedisQueue`` used by the launcher tests.

    Every operation is mirrored onto a per-queue ``MagicMock`` kept in the
    class-level ``_MOCKS`` registry (keyed by queue name), so tests can assert
    on the calls afterwards, e.g. ``MockRedisQueue._MOCKS["healthz"].is_running``.
    """

    # Shared by all instances on purpose: tests look mocks up by queue name.
    # Creating a second queue with the same name replaces the earlier mock.
    _MOCKS = {}

    def __init__(self, name: str, default_timeout: float):
        self.name = name
        self.default_timeout = default_timeout
        # Adding a dummy element: the first get() returns None.
        # NOTE(review): presumably mimics an initially "empty" read — confirm intent.
        self.queue = [None]
        self._MOCKS[name] = MagicMock()

    def put(self, item):
        """Record the call on this queue's mock and enqueue ``item`` at the back."""
        self._MOCKS[self.name].put(item)
        # ``self.queue`` is a plain list (which has no ``put``), so append to it.
        self.queue.append(item)

    def get(self, timeout: int = None):
        """Record the call and pop the front item (raises IndexError when empty)."""
        self._MOCKS[self.name].get(timeout=timeout)
        return self.queue.pop(0)

    @property
    def is_running(self):
        """Record the probe on this queue's mock and always report the queue as running."""
        self._MOCKS[self.name].is_running()
        return True
@mock.patch("lightning.app.core.queues.RedisQueue", MockRedisQueue)
@mock.patch("lightning.app.launcher.launcher.check_if_redis_running", MagicMock(return_value=True))
@mock.patch("lightning.app.launcher.launcher.start_server")
def test_run_server(start_server_mock):
    """Check that ``run_server`` passes host/port and queue-id-scoped queues to ``start_server``.

    ``RedisQueue`` is replaced by ``MockRedisQueue``, the Redis liveness check is
    forced to succeed, and ``start_server`` itself is mocked.
    """
    runner = CliRunner()
    result = runner.invoke(
        run_server,
        [
            _FILE_PATH,
            "--queue-id",
            "1",
            "--host",
            "http://127.0.0.1:7501/view",
            "--port",
            "6000",
        ],
        catch_exceptions=False,
    )
    assert result.exit_code == 0, result.output
    start_server_mock.assert_called_once_with(
        host="http://127.0.0.1:7501/view",
        port=6000,
        api_publish_state_queue=ANY,
        api_delta_queue=ANY,
        api_response_queue=ANY,
        spec=ANY,
        apis=ANY,
    )
    # Use the public ``call_args`` accessor instead of the private ``_mock_call_args``.
    kwargs = start_server_mock.call_args.kwargs
    for queue_kwarg in ("api_publish_state_queue", "api_delta_queue"):
        queue = kwargs[queue_kwarg]
        assert isinstance(queue, MockRedisQueue), queue_kwarg
        # Queue names must be scoped by the ``--queue-id`` passed above.
        assert queue.name.startswith("1"), queue_kwarg
def mock_server(should_catch=False, sleep=1000):
    """Stand-in server target that blocks for ``sleep`` seconds and then returns.

    With ``should_catch=True``, a SIGTERM handler is installed first. The handler
    sleeps for 100 seconds instead of letting the process die, presumably to emulate
    a server that is slow to honour a termination request.
    """
    if not should_catch:
        time.sleep(sleep)
        return

    def _stall_on_sigterm(*_):
        time.sleep(100)

    signal.signal(signal.SIGTERM, _stall_on_sigterm)
    time.sleep(sleep)
def run_forever_process():
    """Process target that never returns on its own; it wakes once per second."""
    keep_running = True
    while keep_running:
        time.sleep(1)
def run_for_2_seconds_and_raise():
    """Process target that sleeps two seconds, then fails with ``RuntimeError``.

    The uncaught exception makes the hosting process exit with a non-zero code.
    """
    delay_seconds = 2
    time.sleep(delay_seconds)
    raise RuntimeError("existing")
def exit_successfully_immediately():
    """Process target that finishes at once, so its process exits with code 0."""
    return None
def start_servers(should_catch=False, sleep=1000):
    """Launch three ``mock_server`` processes (p1, p2, p3) and hand them to the launcher's manager.

    Only p1 receives ``should_catch``; all three sleep for ``sleep`` seconds.
    """
    targets = {
        "p1": partial(mock_server, should_catch=should_catch, sleep=sleep),
        "p2": partial(mock_server, sleep=sleep),
        "p3": partial(mock_server, sleep=sleep),
    }
    # Dicts keep insertion order, so processes start as p1, p2, p3.
    processes = [(label, launcher.start_server_in_process(target=target)) for label, target in targets.items()]
    launcher.manage_server_processes(processes)
@_RunIf(skip_windows=True)
def test_manage_server_processes():
    """Check the exit code of a ``start_servers`` manager process in three shutdown scenarios."""

    def _launch(target):
        proc = Process(target=target)
        proc.start()
        return proc

    # 1. Servers return by themselves after 0.5s: the manager must exit cleanly.
    proc = _launch(partial(start_servers, sleep=0.5))
    proc.join()
    assert proc.exitcode == 0

    # 2. Long-running servers; the manager is sent SIGTERM after 0.5s.
    #    -15 means killed by SIGTERM; 0 presumably means it shut down gracefully.
    proc = _launch(start_servers)
    proc.join(0.5)
    proc.terminate()
    proc.join()
    assert proc.exitcode in [-15, 0]

    # 3. Same, but p1 installs a SIGTERM handler that stalls instead of exiting.
    proc = _launch(partial(start_servers, should_catch=True))
    proc.join(0.5)
    proc.terminate()
    proc.join()
    assert proc.exitcode in [-15, 1]
def start_processes(**functions):
    """Start one launcher-managed process per ``name=target`` keyword and supervise them all."""
    started = [(label, launcher.start_server_in_process(target)) for label, target in functions.items()]
    launcher.manage_server_processes(started)
@_RunIf(skip_windows=True)
def test_manage_server_processes_one_process_gets_killed(capfd):
    """When one managed component crashes, the manager must exit with code 1 and report it.

    p1 runs forever and p2 raises after ~2s. The manager is expected to notice p2's
    non-zero exit code, print a summary table and stop with exit code 1.
    """
    functions = {"p1": run_forever_process, "p2": run_for_2_seconds_and_raise}
    p = Process(target=start_processes, kwargs=functions)
    p.start()
    # Wait up to 40s. join() returns as soon as the process exits, unlike a fixed
    # sleep-then-poll loop.
    p.join(timeout=40)
    if p.is_alive():
        # The manager hung: terminate it so the test fails below instead of leaking a process.
        p.terminate()
        p.join()
    assert p.exitcode == 1
    captured = capfd.readouterr()
    assert (
        "Found dead components with non-zero exit codes, exiting execution!!! Components: \n"
        "| Name | Exit Code |\n|------|-----------|\n| p2 | 1 |\n" in captured.out
    )
@_RunIf(skip_windows=True)
def test_manage_server_processes_all_processes_exits_with_zero_exitcode(capfd):
    """When every managed component exits with code 0, the manager must exit with code 0 too."""
    functions = {
        "p1": exit_successfully_immediately,
        "p2": exit_successfully_immediately,
    }
    p = Process(target=start_processes, kwargs=functions)
    p.start()
    # Wait up to 40s; join() returns as soon as the manager exits.
    p.join(timeout=40)
    if p.is_alive():
        # The manager hung: terminate it so the test fails below instead of leaking a process.
        p.terminate()
        p.join()
    assert p.exitcode == 0
    captured = capfd.readouterr()
    assert "All the components are inactive with exitcode 0. Exiting execution!!!" in captured.out
@mock.patch("lightning.app.launcher.launcher.StorageOrchestrator", MagicMock())
@mock.patch("lightning.app.core.queues.RedisQueue", MockRedisQueue)
@mock.patch("lightning.app.launcher.launcher.manage_server_processes", Mock())
def test_run_flow_and_servers(monkeypatch):
    """Check that ``run_flow_and_servers`` starts the expected processes and exits cleanly.

    ``start_server_in_process`` and ``manage_server_processes`` are mocked, so no
    real processes are spawned or supervised.
    """
    runner = CliRunner()
    start_server_mock = Mock()
    monkeypatch.setattr(launcher, "start_server_in_process", start_server_mock)
    result = runner.invoke(
        run_flow_and_servers,
        [
            str(Path(__file__).parent / "launch_data" / "app_v0" / "app.py"),
            "--base-url",
            "https://some.url",
            "--queue-id",
            "1",
            "--host",
            "http://127.0.0.1:7501/view",
            "--port",
            6000,
            "--flow-port",
            "root.aas",
            6001,
            "--flow-port",
            "root.bbs",
            6002,
        ],
        catch_exceptions=False,
    )
    # ``catch_exceptions=False`` only re-raises exceptions. A ``SystemExit`` with a
    # non-zero code (e.g. a click usage error) is still swallowed into ``exit_code``,
    # so it must be checked explicitly.
    assert result.exit_code == 0, result.output
    # NOTE(review): presumably flow + API server + one frontend per --flow-port (2) = 4; confirm.
    assert start_server_mock.call_count == 4
@mock.patch("lightning.app.core.queues.RedisQueue", MockRedisQueue)
@mock.patch("lightning.app.launcher.launcher.WorkRunner")
def test_run_work(mock_work_runner, monkeypatch):
    """Check that ``run_lightning_work`` builds a WorkRunner wired to queue-id-scoped queues.

    Runs once for every work of the sample app. The work queue is mocked to
    return an ``EmptyWork``, and ``RedisQueue`` is replaced by ``MockRedisQueue``.
    """
    run_app_commands(_FILE_PATH)
    app = load_app_from_file(_FILE_PATH)
    names = [w.name for w in app.works]
    mocked_queue = MagicMock()
    mocked_queue.get.return_value = EmptyWork()
    monkeypatch.setattr(
        QueuingSystem,
        "get_work_queue",
        MagicMock(return_value=mocked_queue),
    )
    assert names == [
        "root.flow_a_1.work_a",
        "root.flow_a_2.work_a",
        "root.flow_b.work_b",
    ]
    # Every queue handed to WorkRunner must be a MockRedisQueue scoped by queue_id "1".
    queue_kwargs = (
        "caller_queue",
        "delta_queue",
        "readiness_queue",
        "error_queue",
        "request_queue",
        "response_queue",
        "copy_request_queue",
        "copy_response_queue",
    )
    for name in names:
        launcher.run_lightning_work(
            file=_FILE_PATH,
            work_name=name,
            queue_id="1",
        )
        # Use the public ``call_args`` accessor instead of the private ``_mock_call_args``.
        kwargs = mock_work_runner.call_args.kwargs
        assert isinstance(kwargs["work"], EmptyWork)
        assert kwargs["work_name"] == name
        for queue_kwarg in queue_kwargs:
            queue = kwargs[queue_kwarg]
            # Pass the kwarg name as the assert message so a failure points at the culprit.
            assert isinstance(queue, MockRedisQueue), queue_kwarg
            assert queue.name.startswith("1"), queue_kwarg
    MockRedisQueue._MOCKS["healthz"].is_running.assert_called()
@mock.patch("lightning.app.core.queues.QueuingSystem", MagicMock())
@mock.patch("lightning.app.launcher.launcher.StorageOrchestrator", MagicMock())
@mock.patch("lightning.app.LightningApp._run")
@mock.patch("lightning.app.launcher.launcher.CloudBackend")
def test_run_flow(mock_cloud_backend, mock_lightning_app_run):
    """Check that ``run_flow`` runs the app once and drives the (mocked) CloudBackend."""
    runner = CliRunner()
    base_url = "https://lightning.ai/me/apps"
    result = runner.invoke(
        run_flow,
        [_FILE_PATH, "--queue-id=1", f"--base-url={base_url}"],
        catch_exceptions=False,
    )
    assert result.exit_code == 0, result.output
    mock_lightning_app_run.assert_called_once()
    # Use the public ``mock_calls`` instead of the private ``_mock_mock_calls``.
    # NOTE(review): 13 is the exact number of calls made on the CloudBackend mock, including
    # the constructor and calls on its return value; update it when run_flow changes.
    assert len(mock_cloud_backend.mock_calls) == 13