feat: allow root path to run the app on `/path` (#14972)

* feat: add base path
* uvicorn fix arg
* Add prefix
* update with base_path fix
* replace base path with root path
* Apply suggestions from code review

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Kaushik B <kaushikbokka@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Pritam Soni 2022-10-07 19:39:40 +05:30 committed by GitHub
parent 8ec7ffb5ce
commit 80080550d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 66 additions and 28 deletions

View File

@ -348,6 +348,7 @@ def start_server(
has_started_queue: Optional[Queue] = None,
host="127.0.0.1",
port=8000,
root_path: str = "",
uvicorn_run: bool = True,
spec: Optional[List] = None,
apis: Optional[List[HttpMethod]] = None,
@ -384,6 +385,6 @@ def start_server(
register_global_routes()
uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error")
uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error", root_path=root_path)
return refresher

View File

@ -49,6 +49,7 @@ class LightningApp:
root: "lightning_app.LightningFlow",
debug: bool = False,
info: frontend.AppInfo = None,
root_path: str = "",
):
"""The Lightning App, or App in short runs a tree of one or more components that interact to create end-to-end
applications. There are two kinds of components: :class:`~lightning_app.core.flow.LightningFlow` and
@ -67,6 +68,11 @@ class LightningApp:
This can be helpful when reporting bugs on Lightning repo.
info: Provide additional info about the app which will be used to update html title,
description and image meta tags and specify any additional tags as list of html strings.
root_path: Set this to `/path` if you want to run your app behind a proxy at `/path` leave empty for "/".
For instance, if you want to run your app at `https://customdomain.com/myapp`,
set `root_path` to `/myapp`.
You can learn more about proxy `here <https://www.fortinet.com/resources/cyberglossary/proxy-server>`_.
.. doctest::
@ -82,6 +88,7 @@ class LightningApp:
Hello World!
"""
self.root_path = root_path # when running behind a proxy
_validate_root_flow(root)
self._root = root
@ -140,7 +147,7 @@ class LightningApp:
# update index.html,
# this should happen once for all apps before the ui server starts running.
frontend.update_index_file_with_info(FRONTEND_DIR, info=info)
frontend.update_index_file(FRONTEND_DIR, info=info, root_path=root_path)
def get_component_by_name(self, component_name: str):
"""Returns the instance corresponding to the given component name."""

View File

@ -15,7 +15,7 @@ class Frontend(ABC):
self.flow: Optional["LightningFlow"] = None
@abstractmethod
def start_server(self, host: str, port: int) -> None:
def start_server(self, host: str, port: int, root_path: str = "") -> None:
"""Start the process that serves the UI at the given hostname and port number.
Arguments:
@ -23,13 +23,16 @@ class Frontend(ABC):
but defaults to localhost when running locally.
port: The port number where the UI will be served. This gets determined by the dispatcher, which by default
chooses any free port when running locally.
root_path: root_path for the server if app in exposed via a proxy at `/<root_path>`
Example:
An custom implementation could look like this:
.. code-block:: python
def start_server(self, host, port):
def start_server(self, host, port, root_path=""):
self._process = subprocess.Popen(["flask", "run" "--host", host, "--port", str(port)])
"""

View File

@ -95,7 +95,7 @@ class PanelFrontend(Frontend):
self._log_files: dict[str, TextIO] = {}
_logger.debug("PanelFrontend Frontend with %s is initialized.", entry_point)
def start_server(self, host: str, port: int) -> None:
def start_server(self, host: str, port: int, root_path: str = "") -> None:
_logger.debug("PanelFrontend starting server on %s:%s", host, port)
# 1: Prepare environment variables and arguments.

View File

@ -20,6 +20,7 @@ class StaticWebFrontend(Frontend):
Arguments:
serve_dir: A local directory to serve files from. This directory should at least contain a file `index.html`.
root_path: A path prefix when routing traffic from behind a proxy at `/<root_path>`
Example:
@ -36,7 +37,7 @@ class StaticWebFrontend(Frontend):
self.serve_dir = serve_dir
self._process: Optional[mp.Process] = None
def start_server(self, host: str, port: int) -> None:
def start_server(self, host: str, port: int, root_path: str = "") -> None:
log_file = str(get_frontend_logfile())
self._process = mp.Process(
target=start_server,
@ -46,6 +47,7 @@ class StaticWebFrontend(Frontend):
serve_dir=self.serve_dir,
path=f"/{self.flow.name}",
log_file=log_file,
root_path=root_path,
),
)
self._process.start()
@ -61,7 +63,9 @@ def healthz():
return {"status": "ok"}
def start_server(serve_dir: str, host: str = "localhost", port: int = -1, path: str = "/", log_file: str = "") -> None:
def start_server(
serve_dir: str, host: str = "localhost", port: int = -1, path: str = "/", log_file: str = "", root_path: str = ""
) -> None:
if port == -1:
port = find_free_network_port()
fastapi_service = FastAPI()
@ -76,11 +80,11 @@ def start_server(serve_dir: str, host: str = "localhost", port: int = -1, path:
# trailing / is required for urljoin to properly join the path. In case of
# multiple trailing /, urljoin removes them
fastapi_service.get(urljoin(f"{path}/", "healthz"), status_code=200)(healthz)
fastapi_service.mount(path, StaticFiles(directory=serve_dir, html=True), name="static")
fastapi_service.mount(urljoin(path, root_path), StaticFiles(directory=serve_dir, html=True), name="static")
log_config = _get_log_config(log_file) if log_file else uvicorn.config.LOGGING_CONFIG
uvicorn.run(app=fastapi_service, host=host, port=port, log_config=log_config)
uvicorn.run(app=fastapi_service, host=host, port=port, log_config=log_config, root_path=root_path)
def _get_log_config(log_file: str) -> dict:
@ -115,7 +119,8 @@ def _get_log_config(log_file: str) -> dict:
if __name__ == "__main__": # pragma: no-cover
parser = ArgumentParser()
parser.add_argument("serve_dir", type=str)
parser.add_argument("root_path", type=str, default="")
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=-1)
args = parser.parse_args()
start_server(serve_dir=args.serve_dir, host=args.host, port=args.port)
start_server(serve_dir=args.serve_dir, host=args.host, port=args.port, root_path=args.root_path)

View File

@ -83,6 +83,7 @@ class MultiProcessRuntime(Runtime):
api_delta_queue=self.app.api_delta_queue,
has_started_queue=has_started_queue,
spec=extract_metadata_from_app(self.app),
root_path=self.app.root_path,
)
server_proc = multiprocessing.Process(target=start_server, kwargs=kwargs)
self.processes["server"] = server_proc

View File

@ -34,6 +34,7 @@ class SingleProcessRuntime(Runtime):
api_delta_queue=self.app.api_delta_queue,
has_started_queue=has_started_queue,
spec=extract_metadata_from_app(self.app),
root_path=self.app.root_path,
)
server_proc = mp.Process(target=start_server, kwargs=kwargs)
self.processes["server"] = server_proc

View File

@ -14,7 +14,7 @@ class AppInfo:
meta_tags: Optional[List[str]] = None
def update_index_file_with_info(ui_root: str, info: AppInfo = None) -> None:
def update_index_file(ui_root: str, info: Optional[AppInfo] = None, root_path: str = "") -> None:
import shutil
from pathlib import Path
@ -27,19 +27,27 @@ def update_index_file_with_info(ui_root: str, info: AppInfo = None) -> None:
# revert index.html in case it was modified after creating original.html
shutil.copyfile(original_file, entry_file)
if not info:
return
if info:
with original_file.open() as f:
original = f.read()
original = ""
with entry_file.open("w") as f:
f.write(_get_updated_content(original=original, root_path=root_path, info=info))
with original_file.open() as f:
original = f.read()
if root_path:
root_path_without_slash = root_path.replace("/", "", 1) if root_path.startswith("/") else root_path
src_dir = Path(ui_root)
dst_dir = src_dir / root_path_without_slash
with entry_file.open("w") as f:
f.write(_get_updated_content(original=original, info=info))
if dst_dir.exists():
shutil.rmtree(dst_dir, ignore_errors=True)
# copy everything except the current root_path, this is to fix a bug if user specifies
# /abc at first and then /abc/def, server don't start
# ideally we should copy everything except custom root_path that user passed.
shutil.copytree(src_dir, dst_dir, ignore=shutil.ignore_patterns(f"{root_path_without_slash}*"))
def _get_updated_content(original: str, info: AppInfo) -> str:
def _get_updated_content(original: str, root_path: str, info: AppInfo) -> str:
soup = BeautifulSoup(original, "html.parser")
# replace favicon
@ -56,6 +64,11 @@ def _get_updated_content(original: str, info: AppInfo) -> str:
soup.find("meta", {"property": "og:image"}).attrs["content"] = info.image
if info.meta_tags:
soup.find("head").append(*[BeautifulSoup(meta, "html.parser") for meta in info.meta_tags])
for meta in info.meta_tags:
soup.find("head").append(BeautifulSoup(meta, "html.parser"))
return str(soup)
if root_path:
# this will be used by lightning app ui to add root_path to add requests
soup.find("head").append(BeautifulSoup(f'<script>window.app_prefix="{root_path}"</script>', "html.parser"))
return str(soup).replace("/static", f"{root_path}/static")

View File

@ -359,6 +359,7 @@ def test_start_server_started():
has_started_queue=has_started_queue,
api_response_queue=api_response_queue,
port=1111,
root_path="",
)
server_proc = mp.Process(target=start_server, kwargs=kwargs)
@ -385,6 +386,7 @@ def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatc
api_delta_queue=api_delta_queue,
has_started_queue=has_started_queue,
api_response_queue=api_response_queue,
root_path="test",
)
monkeypatch.setattr(api, "logger", logging.getLogger())
@ -395,7 +397,7 @@ def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatc
assert "Your app has started. View it in your browser: http://0.0.0.1:1111/view" in caplog.text
ui_refresher.assert_called_once()
uvicorn_run.assert_called_once_with(host="0.0.0.1", port=1111, log_level="error", app=mock.ANY)
uvicorn_run.assert_called_once_with(host="0.0.0.1", port=1111, log_level="error", app=mock.ANY, root_path="test")
class InputRequestModel(BaseModel):

View File

@ -39,6 +39,7 @@ def test_start_stop_server_through_frontend(process_mock):
"serve_dir": ".",
"path": "/root.my.flow",
"log_file": os.path.join(log_file_root, "frontend", "logs.log"),
"root_path": "",
},
)
process_mock().start.assert_called_once()
@ -47,24 +48,28 @@ def test_start_stop_server_through_frontend(process_mock):
@mock.patch("lightning_app.frontend.web.uvicorn")
def test_start_server_through_function(uvicorn_mock, tmpdir, monkeypatch):
@pytest.mark.parametrize("root_path", ["", "/base"])
def test_start_server_through_function(uvicorn_mock, tmpdir, monkeypatch, root_path):
FastAPIMock = MagicMock()
FastAPIMock.mount = MagicMock()
FastAPIGetDecoratorMock = MagicMock()
FastAPIMock.get.return_value = FastAPIGetDecoratorMock
monkeypatch.setattr(lightning_app.frontend.web, "FastAPI", MagicMock(return_value=FastAPIMock))
lightning_app.frontend.web.start_server(serve_dir=tmpdir, host="myhost", port=1000, path="/test-flow")
uvicorn_mock.run.assert_called_once_with(app=ANY, host="myhost", port=1000, log_config=ANY)
FastAPIMock.mount.assert_called_once_with("/test-flow", ANY, name="static")
lightning_app.frontend.web.start_server(
serve_dir=tmpdir, host="myhost", port=1000, path="/test-flow", root_path=root_path
)
uvicorn_mock.run.assert_called_once_with(app=ANY, host="myhost", port=1000, log_config=ANY, root_path=root_path)
FastAPIMock.mount.assert_called_once_with(root_path or "/test-flow", ANY, name="static")
FastAPIMock.get.assert_called_once_with("/test-flow/healthz", status_code=200)
FastAPIGetDecoratorMock.assert_called_once_with(healthz)
# path has default value "/"
FastAPIMock.mount = MagicMock()
lightning_app.frontend.web.start_server(serve_dir=tmpdir, host="myhost", port=1000)
FastAPIMock.mount.assert_called_once_with("/", ANY, name="static")
lightning_app.frontend.web.start_server(serve_dir=tmpdir, host="myhost", port=1000, root_path=root_path)
FastAPIMock.mount.assert_called_once_with(root_path or "/", ANY, name="static")
def test_healthz():