From 8ec7ddf317ad2a185abaa15d66141e1115c96c79 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 20 Oct 2022 15:18:06 +0100 Subject: [PATCH] [App] Pass LightningWork to LightningApp (#15215) * update * update * update * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * ll Co-authored-by: Jirka Borovec Co-authored-by: Jirka Co-authored-by: Luca Antiga --- src/lightning_app/CHANGELOG.md | 2 +- src/lightning_app/core/app.py | 10 +++++++--- src/lightning_app/core/flow.py | 12 ++++++++++++ tests/tests_app/conftest.py | 3 ++- tests/tests_app/core/test_lightning_work.py | 18 ++++++++++++++++++ 5 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index cbdd55f632..9bead9acfb 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for adding descriptions to commands either through a docstring or the `DESCRIPTION` attribute ([#15193](https://github.com/Lightning-AI/lightning/pull/15193) - Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187) - Added a Database Component ([#14995](https://github.com/Lightning-AI/lightning/pull/14995) - +- Added support to pass a `LightningWork` to the `LightningApp` ([#15215](https://github.com/Lightning-AI/lightning/pull/15215) ### Fixed diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 879b7fcf4e..0552c146f4 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -46,7 +46,7 @@ logger = Logger(__name__) class LightningApp: def __init__( self, - root: "lightning_app.LightningFlow", + root: "t.Union[lightning_app.LightningFlow, lightning_app.LightningWork]", debug: bool = False, info: frontend.AppInfo = None, root_path: str = "", @@ -62,8 +62,8 @@ class LightningApp: the :class:`~lightning_app.core.flow.LightningFlow` provided. Arguments: - root: The root LightningFlow component, that defines all the app's nested components, running infinitely. - It must define a `run()` method that the app can call. + root: The root ``LightningFlow`` or ``LightningWork`` component, that defines all the app's nested + components, running infinitely. It must define a `run()` method that the app can call. debug: Whether to activate the Lightning Logger debug mode. This can be helpful when reporting bugs on Lightning repo. info: Provide additional info about the app which will be used to update html title, @@ -89,6 +89,10 @@ class LightningApp: """ self.root_path = root_path # when running behind a proxy + + if isinstance(root, lightning_app.LightningWork): + root = lightning_app.core.flow._RootFlow(root) + _validate_root_flow(root) self._root = root diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py index 9fe723e2b9..f42bd48628 100644 --- a/src/lightning_app/core/flow.py +++ b/src/lightning_app/core/flow.py @@ -763,3 +763,15 @@ class LightningFlow: child.set_state(state) elif strict: raise ValueError(f"The component {child_name} wasn't instantiated for the component {self.name}") + + +class _RootFlow(LightningFlow): + def __init__(self, work): + super().__init__() + self.work = work + + def run(self): + self.work.run() + + def configure_layout(self): + return [{"name": "Main", "content": self.work}] diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py index e8b887637e..138e2b0808 100644 --- a/tests/tests_app/conftest.py +++ b/tests/tests_app/conftest.py @@ -75,7 +75,8 @@ def clear_app_state_state_variables(): lightning_app.utilities.state._STATE = None lightning_app.utilities.state._LAST_STATE = None AppState._MY_AFFILIATION = () - cloud_compute._CLOUD_COMPUTE_STORE.clear() + if hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"): + cloud_compute._CLOUD_COMPUTE_STORE.clear() @pytest.fixture diff --git a/tests/tests_app/core/test_lightning_work.py b/tests/tests_app/core/test_lightning_work.py index e0619420f4..f36ca030be 100644 --- a/tests/tests_app/core/test_lightning_work.py +++ b/tests/tests_app/core/test_lightning_work.py @@ -10,6 +10,7 @@ from lightning_app.core.work import BuildConfig, LightningWork, LightningWorkExc from lightning_app.runners import MultiProcessRuntime from lightning_app.storage import Path from lightning_app.testing.helpers import EmptyFlow, EmptyWork, MockQueue +from lightning_app.testing.testing import LightningTestApp from lightning_app.utilities.enum import WorkStageStatus from lightning_app.utilities.proxies import ProxyWorkRun, WorkRunner @@ -327,3 +328,20 @@ def test_work_local_build_config_provided(): w = Work() w.run() + + +class WorkCounter(LightningWork): + def run(self): + pass + + +class LightningTestAppWithWork(LightningTestApp): + def on_before_run_once(self): + if self.root.work.has_succeeded: + return True + return super().on_before_run_once() + + +def test_lightning_app_with_work(): + app = LightningTestAppWithWork(WorkCounter()) + MultiProcessRuntime(app, start_server=False).dispatch()