From 2a873da042d913d9a93e9d6f16ccdd123d020063 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 15 Jul 2022 19:12:40 +0100 Subject: [PATCH] Add --app_args support from the CLI (#13625) --- examples/app_argparse/app.py | 28 ++++++++++ src/lightning_app/cli/lightning_cli.py | 19 ++++--- src/lightning_app/testing/testing.py | 10 +++- src/lightning_app/utilities/load_app.py | 46 +++++++++++++++- tests/tests_app_examples/test_argparse.py | 67 +++++++++++++++++++++++ 5 files changed, 159 insertions(+), 11 deletions(-) create mode 100644 examples/app_argparse/app.py create mode 100644 tests/tests_app_examples/test_argparse.py diff --git a/examples/app_argparse/app.py b/examples/app_argparse/app.py new file mode 100644 index 0000000000..98683d7ef3 --- /dev/null +++ b/examples/app_argparse/app.py @@ -0,0 +1,28 @@ +import argparse + +import lightning as L + + +class Work(L.LightningWork): + def __init__(self, cloud_compute): + super().__init__(cloud_compute=cloud_compute) + + def run(self): + pass + + +class Flow(L.LightningFlow): + def __init__(self, cloud_compute): + super().__init__() + self.work = Work(cloud_compute) + + def run(self): + assert self.work.cloud_compute.name == "gpu", self.work.cloud_compute.name + self._exit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--use_gpu", action="store_true", default=False, help="Whether to use GPU in the cloud") + hparams = parser.parse_args() + app = L.LightningApp(Flow(L.CloudCompute("gpu" if hparams.use_gpu else "cpu"))) diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py index fb39f743ec..696269c712 100644 --- a/src/lightning_app/cli/lightning_cli.py +++ b/src/lightning_app/cli/lightning_cli.py @@ -1,7 +1,7 @@ import logging import os from pathlib import Path -from typing import Tuple, Union +from typing import List, Tuple, Union import click from requests.exceptions import ConnectionError @@ -109,8 +109,17 @@ def run(): @click.option("--blocking", "blocking", type=bool, default=False) @click.option("--open-ui", type=bool, default=True, help="Decide whether to launch the app UI in a web browser") @click.option("--env", type=str, default=[], multiple=True, help="Env variables to be set for the app.") +@click.option("--app_args", type=str, default=[], multiple=True, help="Collection of arguments for the app.") def run_app( - file: str, cloud: bool, without_server: bool, no_cache: bool, name: str, blocking: bool, open_ui: bool, env: tuple + file: str, + cloud: bool, + without_server: bool, + no_cache: bool, + name: str, + blocking: bool, + open_ui: bool, + env: tuple, + app_args: List[str], ): """Run an app from a file.""" _run_app(file, cloud, without_server, no_cache, name, blocking, open_ui, env) @@ -263,10 +272,4 @@ def _prepare_file(file: str) -> str: if exists: return file - if not exists and file == "quick_start.py": - from lightning_app.demo.quick_start import app - - logger.info(f"For demo purposes, Lightning will run the {app.__file__} file.") - return app.__file__ - raise FileNotFoundError(f"The provided file {file} hasn't been found.") diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py index 7ae9bf6274..e72c6d05ae 100644 --- a/src/lightning_app/testing/testing.py +++ b/src/lightning_app/testing/testing.py @@ -76,14 +76,20 @@ class LightningTestApp(LightningApp): @requires("click") -def application_testing(lightning_app_cls: Type[LightningTestApp], command_line: List[str] = []) -> Any: +def application_testing( + lightning_app_cls: Type[LightningTestApp] = LightningTestApp, command_line: List[str] = [] +) -> Any: from unittest import mock from click.testing import CliRunner with mock.patch("lightning.LightningApp", lightning_app_cls): + original = sys.argv + sys.argv = command_line runner = CliRunner() - return runner.invoke(run_app, command_line, catch_exceptions=False) + result = runner.invoke(run_app, command_line, catch_exceptions=False) + sys.argv = original + return result class SingleWorkFlow(LightningFlow): diff --git a/src/lightning_app/utilities/load_app.py b/src/lightning_app/utilities/load_app.py index 3fef4e63f9..0fff863bc4 100644 --- a/src/lightning_app/utilities/load_app.py +++ b/src/lightning_app/utilities/load_app.py @@ -4,6 +4,7 @@ import os import sys import traceback import types +from contextlib import contextmanager from typing import Dict, List, TYPE_CHECKING, Union from lightning_app.utilities.exceptions import MisconfigurationException @@ -26,7 +27,8 @@ def load_app_from_file(filepath: str) -> "LightningApp": code = _create_code(filepath) module = _create_fake_main_module(filepath) try: - exec(code, module.__dict__) + with _patch_sys_argv(): + exec(code, module.__dict__) except Exception: # we want to format the exception as if no frame was on top. exp, val, tb = sys.exc_info() @@ -113,6 +115,48 @@ def _create_fake_main_module(script_path): return module +@contextmanager +def _patch_sys_argv(): + """This function modifies the ``sys.argv`` by extracting the arguments after ``--app_args`` and removed + everything else before executing the user app script. + + The command: ``lightning run app app.py --without-server --app_args --use_gpu --env ...`` will be converted into + ``app.py --use_gpu`` + """ + from lightning_app.cli.lightning_cli import run_app + + original_argv = sys.argv + # 1: Remove the CLI command + if sys.argv[:3] == ["lightning", "run", "app"]: + sys.argv = sys.argv[3:] + + if "--app_args" not in sys.argv: + # 2: If app_args wasn't used, there is no arguments, so we assign the shorten arguments. + new_argv = sys.argv[:1] + else: + # 3: Collect all the arguments from the CLI + options = [p.opts[0] for p in run_app.params[1:] if p.opts[0] != "--app_args"] + argv_slice = sys.argv + # 4: Find the index of `app_args` + first_index = argv_slice.index("--app_args") + 1 + # 5: Find the next argument from the CLI if any. + matches = [ + argv_slice.index(opt) for opt in options if opt in argv_slice and argv_slice.index(opt) >= first_index + ] + if not matches: + last_index = len(argv_slice) + else: + last_index = min(matches) + # 6: last_index is either the fully command or the latest match from the CLI options. + new_argv = [argv_slice[0]] + argv_slice[first_index:last_index] + + # 7: Patch the command + sys.argv = new_argv + yield + # 8: Restore the command + sys.argv = original_argv + + def component_to_metadata(obj: Union["LightningWork", "LightningFlow"]) -> Dict: from lightning_app import LightningWork diff --git a/tests/tests_app_examples/test_argparse.py b/tests/tests_app_examples/test_argparse.py new file mode 100644 index 0000000000..0c1e55b0d4 --- /dev/null +++ b/tests/tests_app_examples/test_argparse.py @@ -0,0 +1,67 @@ +import os +import sys + +from lightning_app import _PACKAGE_ROOT +from lightning_app.testing.testing import application_testing +from lightning_app.utilities.load_app import _patch_sys_argv + + +def test_app_argparse_example(): + original_argv = sys.argv + + command_line = [ + os.path.join(os.path.dirname(os.path.dirname(_PACKAGE_ROOT)), "examples/app_argparse/app.py"), + "--app_args", + "--use_gpu", + "--without-server", + ] + result = application_testing(command_line=command_line) + assert result.exit_code == 0, result.__dict__ + assert sys.argv == original_argv + + +def test_patch_sys_argv(): + original_argv = sys.argv + + sys.argv = expected = ["lightning", "run", "app", "app.py"] + with _patch_sys_argv(): + assert sys.argv == ["app.py"] + + assert sys.argv == expected + + sys.argv = expected = ["lightning", "run", "app", "app.py", "--without-server", "--env", "name=something"] + with _patch_sys_argv(): + assert sys.argv == ["app.py"] + + assert sys.argv == expected + + sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args"] + with _patch_sys_argv(): + assert sys.argv == ["app.py"] + + assert sys.argv == expected + + sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args", "--env", "name=something"] + with _patch_sys_argv(): + assert sys.argv == ["app.py"] + + assert sys.argv == expected + + sys.argv = expected = [ + "lightning", + "run", + "app", + "app.py", + "--without-server", + "--app_args", + "--use_gpu", + "--name=hello", + "--env", + "name=something", + ] + with _patch_sys_argv(): + assert sys.argv == ["app.py", "--use_gpu", "--name=hello"] + + assert sys.argv == expected + + sys.argv = original_argv