Add --app_args support from the CLI (#13625)

This commit is contained in:
thomas chaton 2022-07-15 19:12:40 +01:00 committed by GitHub
parent aa62fe36df
commit 2a873da042
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 159 additions and 11 deletions

View File

@ -0,0 +1,28 @@
import argparse
import lightning as L
class Work(L.LightningWork):
def __init__(self, cloud_compute):
super().__init__(cloud_compute=cloud_compute)
def run(self):
pass
class Flow(L.LightningFlow):
def __init__(self, cloud_compute):
super().__init__()
self.work = Work(cloud_compute)
def run(self):
assert self.work.cloud_compute.name == "gpu", self.work.cloud_compute.name
self._exit()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", action="store_true", default=False, help="Whether to use GPU in the cloud")
hparams = parser.parse_args()
app = L.LightningApp(Flow(L.CloudCompute("gpu" if hparams.use_gpu else "cpu")))

View File

@ -1,7 +1,7 @@
import logging
import os
from pathlib import Path
from typing import Tuple, Union
from typing import List, Tuple, Union
import click
from requests.exceptions import ConnectionError
@ -109,8 +109,17 @@ def run():
@click.option("--blocking", "blocking", type=bool, default=False)
@click.option("--open-ui", type=bool, default=True, help="Decide whether to launch the app UI in a web browser")
@click.option("--env", type=str, default=[], multiple=True, help="Env variables to be set for the app.")
@click.option("--app_args", type=str, default=[], multiple=True, help="Collection of arguments for the app.")
def run_app(
file: str, cloud: bool, without_server: bool, no_cache: bool, name: str, blocking: bool, open_ui: bool, env: tuple
file: str,
cloud: bool,
without_server: bool,
no_cache: bool,
name: str,
blocking: bool,
open_ui: bool,
env: tuple,
app_args: List[str],
):
"""Run an app from a file."""
_run_app(file, cloud, without_server, no_cache, name, blocking, open_ui, env)
@ -263,10 +272,4 @@ def _prepare_file(file: str) -> str:
if exists:
return file
if not exists and file == "quick_start.py":
from lightning_app.demo.quick_start import app
logger.info(f"For demo purposes, Lightning will run the {app.__file__} file.")
return app.__file__
raise FileNotFoundError(f"The provided file {file} hasn't been found.")

View File

@ -76,14 +76,20 @@ class LightningTestApp(LightningApp):
@requires("click")
def application_testing(lightning_app_cls: Type[LightningTestApp], command_line: List[str] = []) -> Any:
def application_testing(
lightning_app_cls: Type[LightningTestApp] = LightningTestApp, command_line: List[str] = []
) -> Any:
from unittest import mock
from click.testing import CliRunner
with mock.patch("lightning.LightningApp", lightning_app_cls):
original = sys.argv
sys.argv = command_line
runner = CliRunner()
return runner.invoke(run_app, command_line, catch_exceptions=False)
result = runner.invoke(run_app, command_line, catch_exceptions=False)
sys.argv = original
return result
class SingleWorkFlow(LightningFlow):

View File

@ -4,6 +4,7 @@ import os
import sys
import traceback
import types
from contextlib import contextmanager
from typing import Dict, List, TYPE_CHECKING, Union
from lightning_app.utilities.exceptions import MisconfigurationException
@ -26,7 +27,8 @@ def load_app_from_file(filepath: str) -> "LightningApp":
code = _create_code(filepath)
module = _create_fake_main_module(filepath)
try:
exec(code, module.__dict__)
with _patch_sys_argv():
exec(code, module.__dict__)
except Exception:
# we want to format the exception as if no frame was on top.
exp, val, tb = sys.exc_info()
@ -113,6 +115,48 @@ def _create_fake_main_module(script_path):
return module
@contextmanager
def _patch_sys_argv():
"""This function modifies the ``sys.argv`` by extracting the arguments after ``--app_args`` and removed
everything else before executing the user app script.
The command: ``lightning run app app.py --without-server --app_args --use_gpu --env ...`` will be converted into
``app.py --use_gpu``
"""
from lightning_app.cli.lightning_cli import run_app
original_argv = sys.argv
# 1: Remove the CLI command
if sys.argv[:3] == ["lightning", "run", "app"]:
sys.argv = sys.argv[3:]
if "--app_args" not in sys.argv:
# 2: If app_args wasn't used, there is no arguments, so we assign the shorten arguments.
new_argv = sys.argv[:1]
else:
# 3: Collect all the arguments from the CLI
options = [p.opts[0] for p in run_app.params[1:] if p.opts[0] != "--app_args"]
argv_slice = sys.argv
# 4: Find the index of `app_args`
first_index = argv_slice.index("--app_args") + 1
# 5: Find the next argument from the CLI if any.
matches = [
argv_slice.index(opt) for opt in options if opt in argv_slice and argv_slice.index(opt) >= first_index
]
if not matches:
last_index = len(argv_slice)
else:
last_index = min(matches)
# 6: last_index is either the fully command or the latest match from the CLI options.
new_argv = [argv_slice[0]] + argv_slice[first_index:last_index]
# 7: Patch the command
sys.argv = new_argv
yield
# 8: Restore the command
sys.argv = original_argv
def component_to_metadata(obj: Union["LightningWork", "LightningFlow"]) -> Dict:
from lightning_app import LightningWork

View File

@ -0,0 +1,67 @@
import os
import sys
from lightning_app import _PACKAGE_ROOT
from lightning_app.testing.testing import application_testing
from lightning_app.utilities.load_app import _patch_sys_argv
def test_app_argparse_example():
original_argv = sys.argv
command_line = [
os.path.join(os.path.dirname(os.path.dirname(_PACKAGE_ROOT)), "examples/app_argparse/app.py"),
"--app_args",
"--use_gpu",
"--without-server",
]
result = application_testing(command_line=command_line)
assert result.exit_code == 0, result.__dict__
assert sys.argv == original_argv
def test_patch_sys_argv():
original_argv = sys.argv
sys.argv = expected = ["lightning", "run", "app", "app.py"]
with _patch_sys_argv():
assert sys.argv == ["app.py"]
assert sys.argv == expected
sys.argv = expected = ["lightning", "run", "app", "app.py", "--without-server", "--env", "name=something"]
with _patch_sys_argv():
assert sys.argv == ["app.py"]
assert sys.argv == expected
sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args"]
with _patch_sys_argv():
assert sys.argv == ["app.py"]
assert sys.argv == expected
sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args", "--env", "name=something"]
with _patch_sys_argv():
assert sys.argv == ["app.py"]
assert sys.argv == expected
sys.argv = expected = [
"lightning",
"run",
"app",
"app.py",
"--without-server",
"--app_args",
"--use_gpu",
"--name=hello",
"--env",
"name=something",
]
with _patch_sys_argv():
assert sys.argv == ["app.py", "--use_gpu", "--name=hello"]
assert sys.argv == expected
sys.argv = original_argv