[App] Pass LightningWork to LightningApp (#15215)

* update

* update

* update

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* ll

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
thomas chaton 2022-10-20 15:18:06 +01:00 committed by GitHub
parent 775e9ebc0f
commit 8ec7ddf317
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 5 deletions

View File

@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for adding descriptions to commands either through a docstring or the `DESCRIPTION` attribute ([#15193](https://github.com/Lightning-AI/lightning/pull/15193)
- Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187)
- Added a Database Component ([#14995](https://github.com/Lightning-AI/lightning/pull/14995)
- Added support to pass a `LightningWork` to the `LightningApp` ([#15215](https://github.com/Lightning-AI/lightning/pull/15215)
### Fixed

View File

@ -46,7 +46,7 @@ logger = Logger(__name__)
class LightningApp:
def __init__(
self,
root: "lightning_app.LightningFlow",
root: "t.Union[lightning_app.LightningFlow, lightning_app.LightningWork]",
debug: bool = False,
info: frontend.AppInfo = None,
root_path: str = "",
@ -62,8 +62,8 @@ class LightningApp:
the :class:`~lightning_app.core.flow.LightningFlow` provided.
Arguments:
root: The root LightningFlow component, that defines all the app's nested components, running infinitely.
It must define a `run()` method that the app can call.
root: The root ``LightningFlow`` or ``LightningWork`` component, that defines all the app's nested
components, running infinitely. It must define a `run()` method that the app can call.
debug: Whether to activate the Lightning Logger debug mode.
This can be helpful when reporting bugs on Lightning repo.
info: Provide additional info about the app which will be used to update html title,
@ -89,6 +89,10 @@ class LightningApp:
"""
self.root_path = root_path # when running behind a proxy
if isinstance(root, lightning_app.LightningWork):
root = lightning_app.core.flow._RootFlow(root)
_validate_root_flow(root)
self._root = root

View File

@ -763,3 +763,15 @@ class LightningFlow:
child.set_state(state)
elif strict:
raise ValueError(f"The component {child_name} wasn't instantiated for the component {self.name}")
class _RootFlow(LightningFlow):
def __init__(self, work):
super().__init__()
self.work = work
def run(self):
self.work.run()
def configure_layout(self):
return [{"name": "Main", "content": self.work}]

View File

@ -75,7 +75,8 @@ def clear_app_state_state_variables():
lightning_app.utilities.state._STATE = None
lightning_app.utilities.state._LAST_STATE = None
AppState._MY_AFFILIATION = ()
cloud_compute._CLOUD_COMPUTE_STORE.clear()
if hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"):
cloud_compute._CLOUD_COMPUTE_STORE.clear()
@pytest.fixture

View File

@ -10,6 +10,7 @@ from lightning_app.core.work import BuildConfig, LightningWork, LightningWorkExc
from lightning_app.runners import MultiProcessRuntime
from lightning_app.storage import Path
from lightning_app.testing.helpers import EmptyFlow, EmptyWork, MockQueue
from lightning_app.testing.testing import LightningTestApp
from lightning_app.utilities.enum import WorkStageStatus
from lightning_app.utilities.proxies import ProxyWorkRun, WorkRunner
@ -327,3 +328,20 @@ def test_work_local_build_config_provided():
w = Work()
w.run()
class WorkCounter(LightningWork):
def run(self):
pass
class LightningTestAppWithWork(LightningTestApp):
def on_before_run_once(self):
if self.root.work.has_succeeded:
return True
return super().on_before_run_once()
def test_lightning_app_with_work():
app = LightningTestAppWithWork(WorkCounter())
MultiProcessRuntime(app, start_server=False).dispatch()