lightning/tests/tests_app/utilities/test_commands.py

163 lines
4.5 KiB
Python

import argparse
import sys
from multiprocessing import Process
from time import sleep
from unittest.mock import MagicMock
import pytest
import requests
from pydantic import BaseModel
from lightning_app import LightningApp, LightningFlow
from lightning_app.cli.commands.app_commands import _run_app_command
from lightning_app.cli.commands.connection import connect, disconnect
from lightning_app.core.constants import APP_SERVER_PORT
from lightning_app.runners import MultiProcessRuntime
from lightning_app.testing.helpers import _RunIf
from lightning_app.utilities.commands.base import _download_command, _validate_client_command, ClientCommand
from lightning_app.utilities.state import AppState
class SweepConfig(BaseModel):
    # Payload that ``SweepCommand.run`` builds from its CLI flags and that
    # ``FlowCommands.sweep`` declares as its ``config`` argument.
    sweep_name: str
    num_trials: int
class SweepCommand(ClientCommand):
    # Client-side command: parses the sweep flags from the command line, packs
    # them into a ``SweepConfig`` and forwards it to the wrapped handler.
    def run(self) -> None:
        cli = argparse.ArgumentParser()
        # Register every flag together with the type argparse should coerce it to.
        for flag, caster in (("--sweep_name", str), ("--num_trials", int)):
            cli.add_argument(flag, type=caster)
        parsed = cli.parse_args()

        payload = SweepConfig(sweep_name=parsed.sweep_name, num_trials=parsed.num_trials)
        # The handler used in this module (``FlowCommands.sweep``) returns ``True``.
        assert self.invoke_handler(config=payload) is True
class FlowCommands(LightningFlow):
    # Flow exposing two commands ("user command" and "sweep"). It exits once
    # the sweep command has been received and exactly one name was recorded.
    def __init__(self):
        super().__init__()
        self.names = []
        self.has_sweep = False

    def run(self):
        # Nothing to do until both commands have left their mark on the state.
        if not self.has_sweep or len(self.names) != 1:
            return
        sleep(1)
        self._exit()

    def trigger_method(self, name: str):
        # Handler behind "user command": remember the received name.
        self.names.append(name)

    def sweep(self, config: SweepConfig):
        # Handler behind "sweep": flag that a sweep was requested; ``config`` is unused.
        self.has_sweep = True
        return True

    def configure_commands(self):
        # One plain-method command and one ``ClientCommand``-wrapped command.
        user_command = {"user command": self.trigger_method}
        sweep_command = {"sweep": SweepCommand(self.sweep)}
        return [user_command, sweep_command]
class DummyConfig(BaseModel):
    # Payload built by ``DummyCommand.run`` and named in the annotation of the
    # module-level ``run`` handler.
    something: str
    something_else: int
class DummyCommand(ClientCommand):
    # Client-side command used by ``test_client_commands``: wraps its two
    # arguments in a ``DummyConfig`` and checks the handler's reply.
    def run(self, something: str, something_else: int) -> None:
        payload = DummyConfig(something=something, something_else=something_else)
        # ``test_client_commands`` mocks the HTTP layer to answer with this exact body.
        assert self.invoke_handler(config=payload) == {"body": 0}
def run(config: DummyConfig):
    """Handler wrapped by ``DummyCommand`` in ``test_client_commands``.

    Args:
        config: The payload sent by the command; annotated as ``DummyConfig``.

    Raises:
        AssertionError: If ``config`` is not a ``DummyConfig`` instance.
    """
    # The handler receives the config model, not the command object that sent
    # it, so the check must target ``DummyConfig`` (it used to test against
    # ``DummyCommand``, which could never hold for a valid payload).
    assert isinstance(config, DummyConfig)
def run_failure_0(name: str):
    """Stub handler annotated with a builtin type.

    ``test_validate_client_command`` expects ``_validate_client_command`` to
    reject it with a message about the provided annotation.
    """
def run_failure_1(name):
    """Stub handler without any annotation.

    ``test_validate_client_command`` expects ``_validate_client_command`` to
    reject it with a message asking to annotate the method.
    """
class CustomModel(BaseModel):
    # Field-less model declared in this test module; it only serves as the
    # annotation of ``run_failure_2``.
    pass
def run_failure_2(name: CustomModel):
    """Stub handler annotated with ``CustomModel``.

    ``test_validate_client_command`` expects ``_validate_client_command`` to
    reject it with a message mentioning
    ``lightning_app/utilities/commands/base.py``.
    """
@_RunIf(skip_windows=True)
def test_validate_client_command():
    """Check that each invalid handler is rejected with its expected message."""
    # (handler, regex the raised exception message must match), checked in order.
    expectations = (
        (run_failure_0, "The provided annotation for the argument name"),
        (run_failure_1, "annotate your method"),
        (run_failure_2, "lightning_app/utilities/commands/base.py"),
    )
    for handler, message in expectations:
        with pytest.raises(Exception, match=message):
            _validate_client_command(ClientCommand(handler))
def test_client_commands(monkeypatch):
    """Validate ``DummyCommand``, download it by class name and run it against a mocked ``requests.post``."""
    import requests

    # Fake HTTP response: status 200 and the JSON body ``DummyCommand.run`` asserts on.
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json = MagicMock(return_value={"body": 0})
    monkeypatch.setattr(requests, "post", MagicMock(return_value=fake_response))

    _validate_client_command(DummyCommand(run))

    # Load the command class back from this very file.
    downloaded = _download_command(
        command_name="something",
        cls_path=__file__,
        cls_name="DummyCommand",
    )
    downloaded._setup("something", app_url="http//")
    downloaded.run(something="1", something_else="1")
def target():
    """Process entry point: build an app around ``FlowCommands`` and dispatch it with ``MultiProcessRuntime``."""
    runtime = MultiProcessRuntime(LightningApp(FlowCommands()))
    runtime.dispatch()
def test_configure_commands(monkeypatch):
    """This test validates command can be used locally with connect and disconnect.

    The app runs in a child process. The test connects to it, fires the
    "user command" and "sweep" commands through the CLI entry point and then
    expects the app process to exit with code 0.
    """
    process = Process(target=target)
    process.start()
    try:
        # Poll the health endpoint (for roughly 15s) until the server accepts connections.
        time_left = 15
        while time_left > 0:
            try:
                requests.get(f"http://localhost:{APP_SERVER_PORT}/healthz")
                break
            except requests.exceptions.ConnectionError:
                sleep(0.1)
                time_left -= 0.1
        else:
            # Only reached when the loop ran out of time without a ``break``.
            pytest.fail("The app server did not become reachable within 15 seconds.")

        sleep(0.5)
        monkeypatch.setattr(sys, "argv", ["lightning", "user", "command", "--name=something"])
        connect("localhost")
        try:
            _run_app_command("localhost", None)
            sleep(0.5)
            state = AppState()
            state._request_state()
            assert state.names == ["something"]

            monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name=my_name", "--num_trials=1"])
            _run_app_command("localhost", None)

            # ``FlowCommands.run`` exits once the sweep arrived; give it up to 15s.
            process.join(timeout=15)
            assert process.exitcode == 0
        finally:
            # Always undo ``connect`` so a failure does not leak into other tests.
            disconnect()
    finally:
        # Never leave the app process behind when the test fails midway.
        if process.is_alive():
            process.terminate()
            process.join(timeout=5)
        if process.is_alive():
            process.kill()
            process.join()