Add --app_args support from the CLI (#13625)
This commit is contained in:
parent
aa62fe36df
commit
2a873da042
|
@ -0,0 +1,28 @@
|
|||
import argparse
|
||||
|
||||
import lightning as L
|
||||
|
||||
|
||||
class Work(L.LightningWork):
|
||||
def __init__(self, cloud_compute):
|
||||
super().__init__(cloud_compute=cloud_compute)
|
||||
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
|
||||
class Flow(L.LightningFlow):
|
||||
def __init__(self, cloud_compute):
|
||||
super().__init__()
|
||||
self.work = Work(cloud_compute)
|
||||
|
||||
def run(self):
|
||||
assert self.work.cloud_compute.name == "gpu", self.work.cloud_compute.name
|
||||
self._exit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--use_gpu", action="store_true", default=False, help="Whether to use GPU in the cloud")
|
||||
hparams = parser.parse_args()
|
||||
app = L.LightningApp(Flow(L.CloudCompute("gpu" if hparams.use_gpu else "cpu")))
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import click
|
||||
from requests.exceptions import ConnectionError
|
||||
|
@ -109,8 +109,17 @@ def run():
|
|||
@click.option("--blocking", "blocking", type=bool, default=False)
|
||||
@click.option("--open-ui", type=bool, default=True, help="Decide whether to launch the app UI in a web browser")
|
||||
@click.option("--env", type=str, default=[], multiple=True, help="Env variables to be set for the app.")
|
||||
@click.option("--app_args", type=str, default=[], multiple=True, help="Collection of arguments for the app.")
|
||||
def run_app(
|
||||
file: str, cloud: bool, without_server: bool, no_cache: bool, name: str, blocking: bool, open_ui: bool, env: tuple
|
||||
file: str,
|
||||
cloud: bool,
|
||||
without_server: bool,
|
||||
no_cache: bool,
|
||||
name: str,
|
||||
blocking: bool,
|
||||
open_ui: bool,
|
||||
env: tuple,
|
||||
app_args: List[str],
|
||||
):
|
||||
"""Run an app from a file."""
|
||||
_run_app(file, cloud, without_server, no_cache, name, blocking, open_ui, env)
|
||||
|
@ -263,10 +272,4 @@ def _prepare_file(file: str) -> str:
|
|||
if exists:
|
||||
return file
|
||||
|
||||
if not exists and file == "quick_start.py":
|
||||
from lightning_app.demo.quick_start import app
|
||||
|
||||
logger.info(f"For demo purposes, Lightning will run the {app.__file__} file.")
|
||||
return app.__file__
|
||||
|
||||
raise FileNotFoundError(f"The provided file {file} hasn't been found.")
|
||||
|
|
|
@ -76,14 +76,20 @@ class LightningTestApp(LightningApp):
|
|||
|
||||
|
||||
@requires("click")
|
||||
def application_testing(lightning_app_cls: Type[LightningTestApp], command_line: List[str] = []) -> Any:
|
||||
def application_testing(
|
||||
lightning_app_cls: Type[LightningTestApp] = LightningTestApp, command_line: List[str] = []
|
||||
) -> Any:
|
||||
from unittest import mock
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
with mock.patch("lightning.LightningApp", lightning_app_cls):
|
||||
original = sys.argv
|
||||
sys.argv = command_line
|
||||
runner = CliRunner()
|
||||
return runner.invoke(run_app, command_line, catch_exceptions=False)
|
||||
result = runner.invoke(run_app, command_line, catch_exceptions=False)
|
||||
sys.argv = original
|
||||
return result
|
||||
|
||||
|
||||
class SingleWorkFlow(LightningFlow):
|
||||
|
|
|
@ -4,6 +4,7 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, TYPE_CHECKING, Union
|
||||
|
||||
from lightning_app.utilities.exceptions import MisconfigurationException
|
||||
|
@ -26,7 +27,8 @@ def load_app_from_file(filepath: str) -> "LightningApp":
|
|||
code = _create_code(filepath)
|
||||
module = _create_fake_main_module(filepath)
|
||||
try:
|
||||
exec(code, module.__dict__)
|
||||
with _patch_sys_argv():
|
||||
exec(code, module.__dict__)
|
||||
except Exception:
|
||||
# we want to format the exception as if no frame was on top.
|
||||
exp, val, tb = sys.exc_info()
|
||||
|
@ -113,6 +115,48 @@ def _create_fake_main_module(script_path):
|
|||
return module
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_sys_argv():
|
||||
"""This function modifies the ``sys.argv`` by extracting the arguments after ``--app_args`` and removed
|
||||
everything else before executing the user app script.
|
||||
|
||||
The command: ``lightning run app app.py --without-server --app_args --use_gpu --env ...`` will be converted into
|
||||
``app.py --use_gpu``
|
||||
"""
|
||||
from lightning_app.cli.lightning_cli import run_app
|
||||
|
||||
original_argv = sys.argv
|
||||
# 1: Remove the CLI command
|
||||
if sys.argv[:3] == ["lightning", "run", "app"]:
|
||||
sys.argv = sys.argv[3:]
|
||||
|
||||
if "--app_args" not in sys.argv:
|
||||
# 2: If app_args wasn't used, there is no arguments, so we assign the shorten arguments.
|
||||
new_argv = sys.argv[:1]
|
||||
else:
|
||||
# 3: Collect all the arguments from the CLI
|
||||
options = [p.opts[0] for p in run_app.params[1:] if p.opts[0] != "--app_args"]
|
||||
argv_slice = sys.argv
|
||||
# 4: Find the index of `app_args`
|
||||
first_index = argv_slice.index("--app_args") + 1
|
||||
# 5: Find the next argument from the CLI if any.
|
||||
matches = [
|
||||
argv_slice.index(opt) for opt in options if opt in argv_slice and argv_slice.index(opt) >= first_index
|
||||
]
|
||||
if not matches:
|
||||
last_index = len(argv_slice)
|
||||
else:
|
||||
last_index = min(matches)
|
||||
# 6: last_index is either the fully command or the latest match from the CLI options.
|
||||
new_argv = [argv_slice[0]] + argv_slice[first_index:last_index]
|
||||
|
||||
# 7: Patch the command
|
||||
sys.argv = new_argv
|
||||
yield
|
||||
# 8: Restore the command
|
||||
sys.argv = original_argv
|
||||
|
||||
|
||||
def component_to_metadata(obj: Union["LightningWork", "LightningFlow"]) -> Dict:
|
||||
from lightning_app import LightningWork
|
||||
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
from lightning_app import _PACKAGE_ROOT
|
||||
from lightning_app.testing.testing import application_testing
|
||||
from lightning_app.utilities.load_app import _patch_sys_argv
|
||||
|
||||
|
||||
def test_app_argparse_example():
|
||||
original_argv = sys.argv
|
||||
|
||||
command_line = [
|
||||
os.path.join(os.path.dirname(os.path.dirname(_PACKAGE_ROOT)), "examples/app_argparse/app.py"),
|
||||
"--app_args",
|
||||
"--use_gpu",
|
||||
"--without-server",
|
||||
]
|
||||
result = application_testing(command_line=command_line)
|
||||
assert result.exit_code == 0, result.__dict__
|
||||
assert sys.argv == original_argv
|
||||
|
||||
|
||||
def test_patch_sys_argv():
|
||||
original_argv = sys.argv
|
||||
|
||||
sys.argv = expected = ["lightning", "run", "app", "app.py"]
|
||||
with _patch_sys_argv():
|
||||
assert sys.argv == ["app.py"]
|
||||
|
||||
assert sys.argv == expected
|
||||
|
||||
sys.argv = expected = ["lightning", "run", "app", "app.py", "--without-server", "--env", "name=something"]
|
||||
with _patch_sys_argv():
|
||||
assert sys.argv == ["app.py"]
|
||||
|
||||
assert sys.argv == expected
|
||||
|
||||
sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args"]
|
||||
with _patch_sys_argv():
|
||||
assert sys.argv == ["app.py"]
|
||||
|
||||
assert sys.argv == expected
|
||||
|
||||
sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args", "--env", "name=something"]
|
||||
with _patch_sys_argv():
|
||||
assert sys.argv == ["app.py"]
|
||||
|
||||
assert sys.argv == expected
|
||||
|
||||
sys.argv = expected = [
|
||||
"lightning",
|
||||
"run",
|
||||
"app",
|
||||
"app.py",
|
||||
"--without-server",
|
||||
"--app_args",
|
||||
"--use_gpu",
|
||||
"--name=hello",
|
||||
"--env",
|
||||
"name=something",
|
||||
]
|
||||
with _patch_sys_argv():
|
||||
assert sys.argv == ["app.py", "--use_gpu", "--name=hello"]
|
||||
|
||||
assert sys.argv == expected
|
||||
|
||||
sys.argv = original_argv
|
Loading…
Reference in New Issue