# Source: lightning/tests/tests_app/utilities/test_commands.py
# (163 lines, 4.5 KiB, Python; revision dated 2022-07-25 17:13:46 +00:00)
import argparse
import sys
from multiprocessing import Process
from time import sleep
from unittest.mock import MagicMock
import pytest
import requests
from pydantic import BaseModel
from lightning import LightningFlow
from lightning_app import LightningApp
from lightning_app.cli.lightning_cli import app_command
from lightning_app.core.constants import APP_SERVER_PORT
from lightning_app.runners import MultiProcessRuntime
from lightning_app.testing.helpers import RunIf
from lightning_app.utilities.commands.base import _command_to_method_and_metadata, _download_command, ClientCommand
from lightning_app.utilities.state import AppState
# Payload model for the "sweep" command; SweepCommand.run builds it from CLI flags.
class SweepConfig(BaseModel):
    # Filled from the ``--sweep_name`` flag.
    sweep_name: str
    # Filled from the ``--num_trials`` flag.
    num_trials: int
class SweepCommand(ClientCommand):
    """Client command that turns CLI flags into a ``SweepConfig`` and sends it to the handler."""

    def run(self) -> None:
        """Parse ``--sweep_name`` and ``--num_trials`` from ``sys.argv`` and invoke the handler.

        The handler's response is asserted to be exactly ``True``.
        """
        print(sys.argv)
        cli = argparse.ArgumentParser()
        for flag, caster in (("--sweep_name", str), ("--num_trials", int)):
            cli.add_argument(flag, type=caster)
        args = cli.parse_args()
        response = self.invoke_handler(
            config=SweepConfig(sweep_name=args.sweep_name, num_trials=args.num_trials)
        )
        assert response is True
class FlowCommands(LightningFlow):
    """Flow exposing two commands: a plain-method one and a ``SweepCommand``-backed one."""

    def __init__(self):
        super().__init__()
        self.names = []
        self.has_sweep = False

    def run(self):
        """Exit, after a 2 second pause, once a sweep was requested and exactly one name is stored."""
        if not self.has_sweep or len(self.names) != 1:
            return
        sleep(2)
        self._exit()

    def trigger_method(self, name: str):
        """Handler for ``user_command``: remember ``name``."""
        self.names.append(name)

    def sweep(self, config: SweepConfig):
        """Handler for ``sweep``: record that a sweep was requested and return ``True``."""
        self.has_sweep = True
        return True

    def configure_commands(self):
        """Return one single-entry mapping per command, keyed by the command name."""
        user_command = {"user_command": self.trigger_method}
        sweep_command = {"sweep": SweepCommand(self.sweep)}
        return [user_command, sweep_command]
# Payload model built by DummyCommand.run from its two arguments.
class DummyConfig(BaseModel):
    # Declared as a string field.
    something: str
    # Declared as an integer field.
    something_else: int
class DummyCommand(ClientCommand):
    """Client command that wraps its two arguments in a ``DummyConfig`` and sends it."""

    def run(self, something: str, something_else: int) -> None:
        """Send the config to the handler and require the reply to equal ``{"body": 0}``."""
        assert self.invoke_handler(
            config=DummyConfig(something=something, something_else=something_else)
        ) == {"body": 0}
def run(config: DummyConfig):
    """Handler wrapped by ``DummyCommand``; it must be given a ``DummyConfig`` payload.

    Args:
        config: The model instance built on the client side.

    Raises:
        AssertionError: If ``config`` is not a ``DummyConfig``.
    """
    # The handler receives the config model, not the command object wrapping it,
    # so the check has to match the annotated type.
    assert isinstance(config, DummyConfig)
def run_failure_0(name: str):
    """No-op handler whose argument is annotated with ``str``; used as a rejected input."""
def run_failure_1(name):
    """No-op handler whose argument is deliberately left un-annotated; used as a rejected input."""
# Field-less model used only as the annotation of run_failure_2.
class CustomModel(BaseModel):
    pass
def run_failure_2(name: CustomModel):
    """No-op handler annotated with ``CustomModel``; used as a rejected input."""
@RunIf(skip_windows=True)
def test_command_to_method_and_metadata():
    """Each badly-annotated handler must be rejected with its own error message."""
    expected_errors = (
        (run_failure_0, "The provided annotation for the argument name"),
        (run_failure_1, "annotate your method"),
        (run_failure_2, "lightning_app/utilities/commands/base.py"),
    )
    for handler, message in expected_errors:
        with pytest.raises(Exception, match=message):
            _command_to_method_and_metadata(ClientCommand(handler))
def test_client_commands(monkeypatch):
    """Run ``DummyCommand`` through download and setup with ``requests.post`` mocked out."""
    import requests

    # Fake HTTP reply: status 200 with a JSON body of {"body": 0}.
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json = MagicMock(return_value={"body": 0})
    monkeypatch.setattr(requests, "post", MagicMock(return_value=fake_response))

    command = DummyCommand(run)
    _, command_metadata = _command_to_method_and_metadata(command)
    extra_metadata = {
        "command": "dummy",
        "affiliation": "root",
        "is_client_command": True,
        "owner": "root",
    }
    command_metadata.update(extra_metadata)

    client_command, models = _download_command(command_metadata, None)
    client_command._setup(metadata=command_metadata, models=models, app_url="http//")
    client_command.run(something="1", something_else="1")
def target():
    """Process entry point: wrap ``FlowCommands`` in a ``LightningApp`` and dispatch it."""
    runtime = MultiProcessRuntime(LightningApp(FlowCommands()))
    runtime.dispatch()
def test_configure_commands(monkeypatch):
    """Start the app in a child process and drive both commands through ``app_command``.

    The test first issues ``user_command`` and checks that the name reached the
    app state, then issues ``sweep`` and expects the child process to finish
    with exit code 0 within about 15 seconds.
    """
    process = Process(target=target)
    process.start()
    try:
        # Poll the health endpoint until it answers or roughly 15 seconds pass.
        time_left = 15
        while time_left > 0:
            try:
                requests.get(f"http://localhost:{APP_SERVER_PORT}/healthz")
                break
            except requests.exceptions.ConnectionError:
                sleep(0.1)
                time_left -= 0.1

        sleep(0.5)
        monkeypatch.setattr(sys, "argv", ["lightning", "user_command", "--name=something"])
        app_command()
        sleep(0.5)
        state = AppState()
        state._request_state()
        assert state.names == ["something"]

        monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name", "my_name", "--num_trials", "1"])
        app_command()

        # Wait for the child to exit, but give up after the time budget. Both
        # conditions must hold to keep waiting; with ``or`` the loop always
        # burned the whole budget and never ended if the child stayed alive.
        time_left = 15
        while time_left > 0 and process.exitcode is None:
            sleep(0.1)
            time_left -= 0.1
        assert process.exitcode == 0
    finally:
        # Do not leave the app running when an assertion or command fails.
        if process.is_alive():
            process.kill()
            process.join()