2022-12-05 21:58:22 +00:00
|
|
|
import os
|
2022-06-30 20:43:04 +00:00
|
|
|
from unittest import mock
|
|
|
|
from unittest.mock import Mock
|
|
|
|
|
2022-12-05 21:58:22 +00:00
|
|
|
import pytest
|
|
|
|
|
2022-06-30 20:43:04 +00:00
|
|
|
from lightning_app import LightningApp, LightningFlow, LightningWork
|
|
|
|
from lightning_app.frontend import StaticWebFrontend, StreamlitFrontend
|
|
|
|
from lightning_app.runners import MultiProcessRuntime
|
|
|
|
from lightning_app.utilities.component import _get_context
|
|
|
|
|
|
|
|
|
|
|
|
def _streamlit_render_fn():
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
class StreamlitFlow(LightningFlow):
    """Flow whose layout is a ``StreamlitFrontend`` with mocked server hooks."""

    def run(self):
        """Call ``_exit()`` right away; this flow does no other work."""
        self._exit()

    def configure_layout(self):
        """Return a Streamlit frontend whose start/stop methods are Mocks."""
        ui = StreamlitFrontend(render_fn=_streamlit_render_fn)
        # Swap the server lifecycle methods for Mocks so calls can be
        # asserted on without the real methods being invoked.
        for hook in ("start_server", "stop_server"):
            setattr(ui, hook, Mock())
        return ui
|
|
|
|
|
|
|
|
|
|
|
|
class WebFlow(LightningFlow):
    """Flow whose layout is a ``StaticWebFrontend`` with mocked server hooks."""

    def run(self):
        """Call ``_exit()`` right away; this flow does no other work."""
        self._exit()

    def configure_layout(self):
        """Return a static web frontend whose start/stop methods are Mocks."""
        ui = StaticWebFrontend(serve_dir="a/b/c")
        # Swap the server lifecycle methods for Mocks so calls can be
        # asserted on without the real methods being invoked.
        for hook in ("start_server", "stop_server"):
            setattr(ui, hook, Mock())
        return ui
|
|
|
|
|
|
|
|
|
|
|
|
class StartFrontendServersTestFlow(LightningFlow):
    """Root flow holding one ``StreamlitFlow`` child and one ``WebFlow`` child."""

    def __init__(self):
        """Create both child flows and attach them as ``flow0`` and ``flow1``."""
        super().__init__()
        streamlit_child = StreamlitFlow()
        self.flow0 = streamlit_child
        web_child = WebFlow()
        self.flow1 = web_child

    def run(self):
        """Call ``_exit()`` right away; the children are never run from here."""
        self._exit()
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch("lightning_app.runners.multiprocess.find_free_network_port")
def test_multiprocess_starts_frontend_servers(*_):
    """Check that dispatching with MultiProcessRuntime starts and stops every flow's frontend server once."""
    root = StartFrontendServersTestFlow()
    app = LightningApp(root)
    runtime = MultiProcessRuntime(app)
    runtime.dispatch()

    # Every child frontend must have been started exactly once ...
    for child in (root.flow0, root.flow1):
        app.frontends[child.name].start_server.assert_called_once()

    # ... and stopped exactly once.
    for child in (root.flow0, root.flow1):
        app.frontends[child.name].stop_server.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
class ContextWork(LightningWork):
    """Work that asserts the component context reads "work" while it runs."""

    def __init__(self):
        """Initialise the base ``LightningWork`` with its defaults."""
        super().__init__()

    def run(self):
        """Assert that ``_get_context()`` reports the "work" context."""
        current = _get_context()
        assert current.value == "work"
|
|
|
|
|
|
|
|
|
2022-12-19 14:06:52 +00:00
|
|
|
class ContextFlow(LightningFlow):
    """Flow that asserts the component context before, during and after running its work."""

    def __init__(self):
        """Attach a ``ContextWork`` and assert that no context is set at construction time."""
        super().__init__()
        self.work = ContextWork()
        current = _get_context()
        assert current is None

    def run(self):
        """Assert the "flow" context around a call to the work's ``run``, then call ``_exit()``."""
        before = _get_context()
        assert before.value == "flow"

        self.work.run()

        # The context must read "flow" again once the work call has returned.
        after = _get_context()
        assert after.value == "flow"

        self._exit()
|
|
|
|
|
|
|
|
|
|
|
|
def test_multiprocess_runtime_sets_context():
    """Check that the runtime sets the component context seen inside Flow and Work.

    The assertions themselves live in ``ContextFlow`` and ``ContextWork``;
    this test only builds the app and dispatches it.
    """
    root = ContextFlow()
    app = LightningApp(root)
    MultiProcessRuntime(app).dispatch()
|
2022-12-05 21:58:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env,expected_url",
    [
        ({}, "http://127.0.0.1:7501/view"),
        ({"APP_SERVER_HOST": "http://test"}, "http://test"),
    ],
)
def test_get_app_url(env, expected_url):
    """Check the URL returned by ``MultiProcessRuntime._get_app_url`` for a given environment.

    Args:
        env: Environment variables that should be visible during the call.
        expected_url: URL that ``_get_app_url`` is expected to return.
    """
    # clear=True makes ``env`` the *entire* environment for the duration of the
    # block. Without it, an ``APP_SERVER_HOST`` already exported in the caller's
    # shell would leak into the ``{}`` case and the default-URL assertion would
    # depend on the machine running the tests. ``patch.dict`` restores the
    # original environment on exit.
    with mock.patch.dict(os.environ, env, clear=True):
        assert MultiProcessRuntime._get_app_url() == expected_url
|