diff --git a/examples/app_display_name/.lightningignore b/examples/app_display_name/.lightningignore new file mode 100644 index 0000000000..f7275bbbd0 --- /dev/null +++ b/examples/app_display_name/.lightningignore @@ -0,0 +1 @@ +venv/ diff --git a/examples/app_display_name/app.py b/examples/app_display_name/app.py new file mode 100644 index 0000000000..f06d8ee562 --- /dev/null +++ b/examples/app_display_name/app.py @@ -0,0 +1,25 @@ +import lightning as L + + +class Work(L.LightningWork): + def __init__(self, start_with_flow=True): + super().__init__(start_with_flow=start_with_flow) + + def run(self): + pass + + +class Flow(L.LightningFlow): + def __init__(self): + super().__init__() + self.w = Work() + self.w1 = Work(start_with_flow=False) + self.w.display_name = "My Custom Name" # Not supported yet + self.w1.display_name = "My Custom Name 1" + + def run(self): + self.w.run() + self.w1.run() + + +app = L.LightningApp(Flow()) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index b427988b92..1ffd4f1efb 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018)) +- Added `display_name` property to LightningWork for the cloud ([#16095](https://github.com/Lightning-AI/lightning/pull/16095)) + + ### Changed - diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py index 982e02d959..379e87cb68 100644 --- a/src/lightning_app/api/http_methods.py +++ b/src/lightning_app/api/http_methods.py @@ -8,7 +8,7 @@ from multiprocessing import Queue from typing import Any, Callable, Dict, List, Optional from uuid import uuid4 -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException, Request, status from lightning_utilities.core.apply_func import apply_to_collection from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse @@ -170,7 +170,10 @@ class _HttpMethod: while request_id not in responses_store: await asyncio.sleep(0.01) if (time.time() - t0) > self.timeout: - raise Exception("The response was never received.") + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="The response was never received.", + ) logger.debug(f"Processed request {request_id} for route: {self.route}") diff --git a/src/lightning_app/core/work.py b/src/lightning_app/core/work.py index 029f01fd2f..12203d4db3 100644 --- a/src/lightning_app/core/work.py +++ b/src/lightning_app/core/work.py @@ -119,7 +119,16 @@ class LightningWork: " in the next version. Use `cache_calls` instead." ) self._cache_calls = run_once if run_once is not None else cache_calls - self._state = {"_host", "_port", "_url", "_future_url", "_internal_ip", "_restarting", "_cloud_compute"} + self._state = { + "_host", + "_port", + "_url", + "_future_url", + "_internal_ip", + "_restarting", + "_cloud_compute", + "_display_name", + } self._parallel = parallel self._host: str = host self._port: Optional[int] = port @@ -129,6 +138,7 @@ class LightningWork: # setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator self._setattr_replacement: Optional[Callable[[str, Any], None]] = None self._name = "" + self._display_name = "" # The ``self._calls`` is used to track whether the run # method with a given set of input arguments has already been called. # Example of its usage: @@ -207,6 +217,22 @@ class LightningWork: """Returns the name of the LightningWork.""" return self._name + @property + def display_name(self): + """Returns the display name of the LightningWork in the cloud. + + The display name needs to set before the run method of the work is called. + """ + return self._display_name + + @display_name.setter + def display_name(self, display_name: str): + """Sets the display name of the LightningWork in the cloud.""" + if not self.has_started: + self._display_name = display_name + elif self._display_name != display_name: + raise RuntimeError("The display name can be set only before the work has started.") + @property def cache_calls(self) -> bool: """Returns whether the ``run`` method should cache its input arguments and not run again when provided with diff --git a/src/lightning_app/utilities/frontend.py b/src/lightning_app/utilities/frontend.py index 470036436a..afc5f21539 100644 --- a/src/lightning_app/utilities/frontend.py +++ b/src/lightning_app/utilities/frontend.py @@ -22,11 +22,12 @@ def update_index_file(ui_root: str, info: Optional[AppInfo] = None, root_path: s entry_file = Path(ui_root) / "index.html" original_file = Path(ui_root) / "index.original.html" - if not original_file.exists(): - shutil.copyfile(entry_file, original_file) # keep backup - else: - # revert index.html in case it was modified after creating original.html - shutil.copyfile(original_file, entry_file) + if root_path: + if not original_file.exists(): + shutil.copyfile(entry_file, original_file) # keep backup + else: + # revert index.html in case it was modified after creating original.html + shutil.copyfile(original_file, entry_file) if info: with original_file.open() as f: diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index ea552adad7..d397bb23e5 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -124,6 +124,7 @@ def test_simple_app(tmpdir): "_paths": {}, "_port": None, "_restarting": False, + "_display_name": "", }, "calls": {"latest_call_hash": None}, "changes": {}, @@ -140,6 +141,7 @@ def test_simple_app(tmpdir): "_paths": {}, "_port": None, "_restarting": False, + "_display_name": "", }, "calls": {"latest_call_hash": None}, "changes": {}, @@ -969,7 +971,7 @@ class SizeFlow(LightningFlow): def test_state_size_constant_growth(): app = LightningApp(SizeFlow()) MultiProcessRuntime(app, start_server=False).dispatch() - assert app.root._state_sizes[0] <= 7824 + assert app.root._state_sizes[0] <= 7888 assert app.root._state_sizes[20] <= 26500 diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py index c8e9921f29..c2aa52b8e6 100644 --- a/tests/tests_app/core/test_lightning_flow.py +++ b/tests/tests_app/core/test_lightning_flow.py @@ -329,6 +329,7 @@ def test_lightning_flow_and_work(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "default", @@ -352,6 +353,7 @@ def test_lightning_flow_and_work(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "default", @@ -391,6 +393,7 @@ def test_lightning_flow_and_work(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "default", @@ -414,6 +417,7 @@ def test_lightning_flow_and_work(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "default", diff --git a/tests/tests_app/core/test_lightning_work.py b/tests/tests_app/core/test_lightning_work.py index cb97eabfa2..01f55b1f90 100644 --- a/tests/tests_app/core/test_lightning_work.py +++ b/tests/tests_app/core/test_lightning_work.py @@ -11,7 +11,7 @@ from lightning_app.runners import MultiProcessRuntime from lightning_app.storage import Path from lightning_app.testing.helpers import _MockQueue, EmptyFlow, EmptyWork from lightning_app.testing.testing import LightningTestApp -from lightning_app.utilities.enum import WorkStageStatus +from lightning_app.utilities.enum import make_status, WorkStageStatus from lightning_app.utilities.exceptions import LightningWorkException from lightning_app.utilities.packaging.build_config import BuildConfig from lightning_app.utilities.proxies import ProxyWorkRun, WorkRunner @@ -384,3 +384,24 @@ class FlowStart(LightningFlow): def test_lightning_app_work_start(cache_calls, parallel): app = LightningApp(FlowStart(cache_calls, parallel)) MultiProcessRuntime(app, start_server=False).dispatch() + + +class WorkDisplay(LightningWork): + def __init__(self): + super().__init__() + + def run(self): + pass + + +def test_lightning_work_display_name(): + work = WorkDisplay() + assert work.state_vars["vars"]["_display_name"] == "" + work.display_name = "Hello" + assert work.state_vars["vars"]["_display_name"] == "Hello" + + work._calls["latest_call_hash"] = "test" + work._calls["test"] = {"statuses": [make_status(WorkStageStatus.PENDING)]} + with pytest.raises(RuntimeError, match="The display name can be set only before the work has started."): + work.display_name = "HELLO" + work.display_name = "Hello" diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py index 3346da5a85..852589a444 100644 --- a/tests/tests_app/structures/test_structures.py +++ b/tests/tests_app/structures/test_structures.py @@ -44,6 +44,7 @@ def test_dict(): "_host": "127.0.0.1", "_paths": {}, "_restarting": False, + "_display_name": "", "_internal_ip": "", "_cloud_compute": { "type": "__cloud_compute__", @@ -76,6 +77,7 @@ def test_dict(): "_host": "127.0.0.1", "_paths": {}, "_restarting": False, + "_display_name": "", "_internal_ip": "", "_cloud_compute": { "type": "__cloud_compute__", @@ -108,6 +110,7 @@ def test_dict(): "_host": "127.0.0.1", "_paths": {}, "_restarting": False, + "_display_name": "", "_internal_ip": "", "_cloud_compute": { "type": "__cloud_compute__", @@ -193,6 +196,7 @@ def test_list(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "default", @@ -225,6 +229,7 @@ def test_list(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "default", @@ -252,6 +257,7 @@ def test_list(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "default",