[App] Pass LightningWork to LightningApp (#15215)
* update * update * update * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * ll Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka <jirka.borovec@seznam.cz> Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
parent
775e9ebc0f
commit
8ec7ddf317
|
@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for adding descriptions to commands either through a docstring or the `DESCRIPTION` attribute ([#15193](https://github.com/Lightning-AI/lightning/pull/15193)
|
||||
- Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187)
|
||||
- Added a Database Component ([#14995](https://github.com/Lightning-AI/lightning/pull/14995)
|
||||
|
||||
- Added support to pass a `LightningWork` to the `LightningApp` ([#15215](https://github.com/Lightning-AI/lightning/pull/15215)
|
||||
|
||||
### Fixed
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ logger = Logger(__name__)
|
|||
class LightningApp:
|
||||
def __init__(
|
||||
self,
|
||||
root: "lightning_app.LightningFlow",
|
||||
root: "t.Union[lightning_app.LightningFlow, lightning_app.LightningWork]",
|
||||
debug: bool = False,
|
||||
info: frontend.AppInfo = None,
|
||||
root_path: str = "",
|
||||
|
@ -62,8 +62,8 @@ class LightningApp:
|
|||
the :class:`~lightning_app.core.flow.LightningFlow` provided.
|
||||
|
||||
Arguments:
|
||||
root: The root LightningFlow component, that defines all the app's nested components, running infinitely.
|
||||
It must define a `run()` method that the app can call.
|
||||
root: The root ``LightningFlow`` or ``LightningWork`` component, that defines all the app's nested
|
||||
components, running infinitely. It must define a `run()` method that the app can call.
|
||||
debug: Whether to activate the Lightning Logger debug mode.
|
||||
This can be helpful when reporting bugs on Lightning repo.
|
||||
info: Provide additional info about the app which will be used to update html title,
|
||||
|
@ -89,6 +89,10 @@ class LightningApp:
|
|||
"""
|
||||
|
||||
self.root_path = root_path # when running behind a proxy
|
||||
|
||||
if isinstance(root, lightning_app.LightningWork):
|
||||
root = lightning_app.core.flow._RootFlow(root)
|
||||
|
||||
_validate_root_flow(root)
|
||||
self._root = root
|
||||
|
||||
|
|
|
@ -763,3 +763,15 @@ class LightningFlow:
|
|||
child.set_state(state)
|
||||
elif strict:
|
||||
raise ValueError(f"The component {child_name} wasn't instantiated for the component {self.name}")
|
||||
|
||||
|
||||
class _RootFlow(LightningFlow):
|
||||
def __init__(self, work):
|
||||
super().__init__()
|
||||
self.work = work
|
||||
|
||||
def run(self):
|
||||
self.work.run()
|
||||
|
||||
def configure_layout(self):
|
||||
return [{"name": "Main", "content": self.work}]
|
||||
|
|
|
@ -75,7 +75,8 @@ def clear_app_state_state_variables():
|
|||
lightning_app.utilities.state._STATE = None
|
||||
lightning_app.utilities.state._LAST_STATE = None
|
||||
AppState._MY_AFFILIATION = ()
|
||||
cloud_compute._CLOUD_COMPUTE_STORE.clear()
|
||||
if hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"):
|
||||
cloud_compute._CLOUD_COMPUTE_STORE.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -10,6 +10,7 @@ from lightning_app.core.work import BuildConfig, LightningWork, LightningWorkExc
|
|||
from lightning_app.runners import MultiProcessRuntime
|
||||
from lightning_app.storage import Path
|
||||
from lightning_app.testing.helpers import EmptyFlow, EmptyWork, MockQueue
|
||||
from lightning_app.testing.testing import LightningTestApp
|
||||
from lightning_app.utilities.enum import WorkStageStatus
|
||||
from lightning_app.utilities.proxies import ProxyWorkRun, WorkRunner
|
||||
|
||||
|
@ -327,3 +328,20 @@ def test_work_local_build_config_provided():
|
|||
|
||||
w = Work()
|
||||
w.run()
|
||||
|
||||
|
||||
class WorkCounter(LightningWork):
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
|
||||
class LightningTestAppWithWork(LightningTestApp):
|
||||
def on_before_run_once(self):
|
||||
if self.root.work.has_succeeded:
|
||||
return True
|
||||
return super().on_before_run_once()
|
||||
|
||||
|
||||
def test_lightning_app_with_work():
|
||||
app = LightningTestAppWithWork(WorkCounter())
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
|
|
Loading…
Reference in New Issue