[App] Add display name property to the work (#16095)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
parent
3e8319d422
commit
22b254f491
|
@ -0,0 +1 @@
|
|||
venv/
|
|
@ -0,0 +1,25 @@
|
|||
import lightning as L
|
||||
|
||||
|
||||
class Work(L.LightningWork):
|
||||
def __init__(self, start_with_flow=True):
|
||||
super().__init__(start_with_flow=start_with_flow)
|
||||
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
|
||||
class Flow(L.LightningFlow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.w = Work()
|
||||
self.w1 = Work(start_with_flow=False)
|
||||
self.w.display_name = "My Custom Name" # Not supported yet
|
||||
self.w1.display_name = "My Custom Name 1"
|
||||
|
||||
def run(self):
|
||||
self.w.run()
|
||||
self.w1.run()
|
||||
|
||||
|
||||
app = L.LightningApp(Flow())
|
|
@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018))
|
||||
|
||||
- Added `display_name` property to LightningWork for the cloud ([#16095](https://github.com/Lightning-AI/lightning/pull/16095))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
-
|
||||
|
|
|
@ -8,7 +8,7 @@ from multiprocessing import Queue
|
|||
from typing import Any, Callable, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from lightning_utilities.core.apply_func import apply_to_collection
|
||||
|
||||
from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
|
||||
|
@ -170,7 +170,10 @@ class _HttpMethod:
|
|||
while request_id not in responses_store:
|
||||
await asyncio.sleep(0.01)
|
||||
if (time.time() - t0) > self.timeout:
|
||||
raise Exception("The response was never received.")
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="The response was never received.",
|
||||
)
|
||||
|
||||
logger.debug(f"Processed request {request_id} for route: {self.route}")
|
||||
|
||||
|
|
|
@ -119,7 +119,16 @@ class LightningWork:
|
|||
" in the next version. Use `cache_calls` instead."
|
||||
)
|
||||
self._cache_calls = run_once if run_once is not None else cache_calls
|
||||
self._state = {"_host", "_port", "_url", "_future_url", "_internal_ip", "_restarting", "_cloud_compute"}
|
||||
self._state = {
|
||||
"_host",
|
||||
"_port",
|
||||
"_url",
|
||||
"_future_url",
|
||||
"_internal_ip",
|
||||
"_restarting",
|
||||
"_cloud_compute",
|
||||
"_display_name",
|
||||
}
|
||||
self._parallel = parallel
|
||||
self._host: str = host
|
||||
self._port: Optional[int] = port
|
||||
|
@ -129,6 +138,7 @@ class LightningWork:
|
|||
# setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator
|
||||
self._setattr_replacement: Optional[Callable[[str, Any], None]] = None
|
||||
self._name = ""
|
||||
self._display_name = ""
|
||||
# The ``self._calls`` is used to track whether the run
|
||||
# method with a given set of input arguments has already been called.
|
||||
# Example of its usage:
|
||||
|
@ -207,6 +217,22 @@ class LightningWork:
|
|||
"""Returns the name of the LightningWork."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def display_name(self):
|
||||
"""Returns the display name of the LightningWork in the cloud.
|
||||
|
||||
The display name needs to set before the run method of the work is called.
|
||||
"""
|
||||
return self._display_name
|
||||
|
||||
@display_name.setter
|
||||
def display_name(self, display_name: str):
|
||||
"""Sets the display name of the LightningWork in the cloud."""
|
||||
if not self.has_started:
|
||||
self._display_name = display_name
|
||||
elif self._display_name != display_name:
|
||||
raise RuntimeError("The display name can be set only before the work has started.")
|
||||
|
||||
@property
|
||||
def cache_calls(self) -> bool:
|
||||
"""Returns whether the ``run`` method should cache its input arguments and not run again when provided with
|
||||
|
|
|
@ -22,11 +22,12 @@ def update_index_file(ui_root: str, info: Optional[AppInfo] = None, root_path: s
|
|||
entry_file = Path(ui_root) / "index.html"
|
||||
original_file = Path(ui_root) / "index.original.html"
|
||||
|
||||
if not original_file.exists():
|
||||
shutil.copyfile(entry_file, original_file) # keep backup
|
||||
else:
|
||||
# revert index.html in case it was modified after creating original.html
|
||||
shutil.copyfile(original_file, entry_file)
|
||||
if root_path:
|
||||
if not original_file.exists():
|
||||
shutil.copyfile(entry_file, original_file) # keep backup
|
||||
else:
|
||||
# revert index.html in case it was modified after creating original.html
|
||||
shutil.copyfile(original_file, entry_file)
|
||||
|
||||
if info:
|
||||
with original_file.open() as f:
|
||||
|
|
|
@ -124,6 +124,7 @@ def test_simple_app(tmpdir):
|
|||
"_paths": {},
|
||||
"_port": None,
|
||||
"_restarting": False,
|
||||
"_display_name": "",
|
||||
},
|
||||
"calls": {"latest_call_hash": None},
|
||||
"changes": {},
|
||||
|
@ -140,6 +141,7 @@ def test_simple_app(tmpdir):
|
|||
"_paths": {},
|
||||
"_port": None,
|
||||
"_restarting": False,
|
||||
"_display_name": "",
|
||||
},
|
||||
"calls": {"latest_call_hash": None},
|
||||
"changes": {},
|
||||
|
@ -969,7 +971,7 @@ class SizeFlow(LightningFlow):
|
|||
def test_state_size_constant_growth():
|
||||
app = LightningApp(SizeFlow())
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
assert app.root._state_sizes[0] <= 7824
|
||||
assert app.root._state_sizes[0] <= 7888
|
||||
assert app.root._state_sizes[20] <= 26500
|
||||
|
||||
|
||||
|
|
|
@ -329,6 +329,7 @@ def test_lightning_flow_and_work():
|
|||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_internal_ip": "",
|
||||
"_display_name": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
"name": "default",
|
||||
|
@ -352,6 +353,7 @@ def test_lightning_flow_and_work():
|
|||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_internal_ip": "",
|
||||
"_display_name": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
"name": "default",
|
||||
|
@ -391,6 +393,7 @@ def test_lightning_flow_and_work():
|
|||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_internal_ip": "",
|
||||
"_display_name": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
"name": "default",
|
||||
|
@ -414,6 +417,7 @@ def test_lightning_flow_and_work():
|
|||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_internal_ip": "",
|
||||
"_display_name": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
"name": "default",
|
||||
|
|
|
@ -11,7 +11,7 @@ from lightning_app.runners import MultiProcessRuntime
|
|||
from lightning_app.storage import Path
|
||||
from lightning_app.testing.helpers import _MockQueue, EmptyFlow, EmptyWork
|
||||
from lightning_app.testing.testing import LightningTestApp
|
||||
from lightning_app.utilities.enum import WorkStageStatus
|
||||
from lightning_app.utilities.enum import make_status, WorkStageStatus
|
||||
from lightning_app.utilities.exceptions import LightningWorkException
|
||||
from lightning_app.utilities.packaging.build_config import BuildConfig
|
||||
from lightning_app.utilities.proxies import ProxyWorkRun, WorkRunner
|
||||
|
@ -384,3 +384,24 @@ class FlowStart(LightningFlow):
|
|||
def test_lightning_app_work_start(cache_calls, parallel):
|
||||
app = LightningApp(FlowStart(cache_calls, parallel))
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
|
||||
|
||||
class WorkDisplay(LightningWork):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_lightning_work_display_name():
|
||||
work = WorkDisplay()
|
||||
assert work.state_vars["vars"]["_display_name"] == ""
|
||||
work.display_name = "Hello"
|
||||
assert work.state_vars["vars"]["_display_name"] == "Hello"
|
||||
|
||||
work._calls["latest_call_hash"] = "test"
|
||||
work._calls["test"] = {"statuses": [make_status(WorkStageStatus.PENDING)]}
|
||||
with pytest.raises(RuntimeError, match="The display name can be set only before the work has started."):
|
||||
work.display_name = "HELLO"
|
||||
work.display_name = "Hello"
|
||||
|
|
|
@ -44,6 +44,7 @@ def test_dict():
|
|||
"_host": "127.0.0.1",
|
||||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_display_name": "",
|
||||
"_internal_ip": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
|
@ -76,6 +77,7 @@ def test_dict():
|
|||
"_host": "127.0.0.1",
|
||||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_display_name": "",
|
||||
"_internal_ip": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
|
@ -108,6 +110,7 @@ def test_dict():
|
|||
"_host": "127.0.0.1",
|
||||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_display_name": "",
|
||||
"_internal_ip": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
|
@ -193,6 +196,7 @@ def test_list():
|
|||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_internal_ip": "",
|
||||
"_display_name": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
"name": "default",
|
||||
|
@ -225,6 +229,7 @@ def test_list():
|
|||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_internal_ip": "",
|
||||
"_display_name": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
"name": "default",
|
||||
|
@ -252,6 +257,7 @@ def test_list():
|
|||
"_paths": {},
|
||||
"_restarting": False,
|
||||
"_internal_ip": "",
|
||||
"_display_name": "",
|
||||
"_cloud_compute": {
|
||||
"type": "__cloud_compute__",
|
||||
"name": "default",
|
||||
|
|
Loading…
Reference in New Issue