From b936fd4380cee0880f430255f63f9f68831d3db7 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 4 Oct 2022 21:46:44 +0200 Subject: [PATCH] [app] Add CloudCompute ID serializable within the flow and works state (#14819) --- .gitignore | 3 +- examples/app_layout/requirements.txt | 1 + .../app_works_on_default_machine/.lightning | 1 + .../app_works_on_default_machine/app_v2.py | 53 ++++++++++++ .../requirements.txt | 1 + requirements/app/base.txt | 2 +- src/lightning_app/CHANGELOG.md | 2 + src/lightning_app/core/constants.py | 10 ++- src/lightning_app/core/flow.py | 9 ++ src/lightning_app/core/work.py | 26 +++++- .../runners/backends/mp_process.py | 4 +- src/lightning_app/runners/cloud.py | 31 ++++++- src/lightning_app/utilities/component.py | 5 ++ src/lightning_app/utilities/layout.py | 2 + src/lightning_app/utilities/network.py | 4 +- .../utilities/packaging/cloud_compute.py | 69 +++++++++++++++- .../utilities/packaging/lightning_utils.py | 10 ++- tests/tests_app/conftest.py | 2 + tests/tests_app/core/test_lightning_app.py | 82 +++++++++++++++++-- tests/tests_app/core/test_lightning_flow.py | 44 ++++++++++ tests/tests_app/structures/test_structures.py | 66 +++++++++++++++ .../packaging/test_lightning_utils.py | 5 +- tests/tests_app/utilities/test_load_app.py | 6 +- 23 files changed, 409 insertions(+), 29 deletions(-) create mode 100644 examples/app_layout/requirements.txt create mode 100644 examples/app_works_on_default_machine/.lightning create mode 100644 examples/app_works_on_default_machine/app_v2.py create mode 100644 examples/app_works_on_default_machine/requirements.txt diff --git a/.gitignore b/.gitignore index c308ce2620..b6eae6055a 100644 --- a/.gitignore +++ b/.gitignore @@ -112,7 +112,8 @@ celerybeat-schedule # dotenv .env -.env_stagging +.env_staging +.env_local # virtualenv .venv diff --git a/examples/app_layout/requirements.txt b/examples/app_layout/requirements.txt new file mode 100644 index 0000000000..12a4706528 --- /dev/null +++ b/examples/app_layout/requirements.txt @@ -0,0 +1 @@ +streamlit diff --git a/examples/app_works_on_default_machine/.lightning b/examples/app_works_on_default_machine/.lightning new file mode 100644 index 0000000000..2eb65c5108 --- /dev/null +++ b/examples/app_works_on_default_machine/.lightning @@ -0,0 +1 @@ +name: app_works_on_default_machine diff --git a/examples/app_works_on_default_machine/app_v2.py b/examples/app_works_on_default_machine/app_v2.py new file mode 100644 index 0000000000..f1d3c36d2a --- /dev/null +++ b/examples/app_works_on_default_machine/app_v2.py @@ -0,0 +1,53 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from uvicorn import run + +from lightning import CloudCompute, LightningApp, LightningFlow, LightningWork + + +class Work(LightningWork): + def __init__(self, **kwargs): + super().__init__(parallel=True, **kwargs) + + def run(self): + fastapi_service = FastAPI() + + fastapi_service.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @fastapi_service.get("/") + def get_root(): + return {"Hello Word!"} + + run(fastapi_service, host=self.host, port=self.port) + + +class Flow(LightningFlow): + def __init__(self): + super().__init__() + # In the Cloud: All the works defined without passing explicitly a CloudCompute object + # are running on the default machine. + # This would apply to `work_a`, `work_b` and the dynamically created `work_d`. 
+ + self.work_a = Work() + self.work_b = Work() + + self.work_c = Work(cloud_compute=CloudCompute(name="cpu-small")) + + def run(self): + if not hasattr(self, "work_d"): + self.work_d = Work() + + for work in self.works(): + work.run() + + def configure_layout(self): + return [{"name": w.name, "content": w} for i, w in enumerate(self.works())] + + +app = LightningApp(Flow(), debug=True) diff --git a/examples/app_works_on_default_machine/requirements.txt b/examples/app_works_on_default_machine/requirements.txt new file mode 100644 index 0000000000..12a4706528 --- /dev/null +++ b/examples/app_works_on_default_machine/requirements.txt @@ -0,0 +1 @@ +streamlit diff --git a/requirements/app/base.txt b/requirements/app/base.txt index 7d82b0e461..8fcf323f58 100644 --- a/requirements/app/base.txt +++ b/requirements/app/base.txt @@ -1,4 +1,4 @@ -lightning-cloud==0.5.7 +lightning-cloud>=0.5.7 packaging deepdiff>=5.7.0, <=5.8.1 starsessions>=1.2.1, <2.0 # strict diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index d7c11179e4..85a6a4a532 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Add `--secret` option to CLI to allow binding Secrets to app environment variables when running in the cloud ([#14612](https://github.com/Lightning-AI/lightning/pull/14612)) +- Add support for running the works without cloud compute in the default container ([#14819](https://github.com/Lightning-AI/lightning/pull/14819)) + ### Changed - diff --git a/src/lightning_app/core/constants.py b/src/lightning_app/core/constants.py index 78400abaff..dc38064223 100644 --- a/src/lightning_app/core/constants.py +++ b/src/lightning_app/core/constants.py @@ -47,5 +47,13 @@ DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_DEBUG", "0"))) LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components" LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps" ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "0"))) - DEBUG: bool = lightning_cloud.env.DEBUG + +# EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE +DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50")) +ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER = bool( + int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0")) +) # Note: This is disabled for the time being. +ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = bool( + int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", "0")) +) # This isn't used in the cloud yet. diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py index 3d265c49e2..9fe723e2b9 100644 --- a/src/lightning_app/core/flow.py +++ b/src/lightning_app/core/flow.py @@ -14,6 +14,7 @@ from lightning_app.utilities.app_helpers import _is_json_serializable, _Lightnin from lightning_app.utilities.component import _sanitize_state from lightning_app.utilities.exceptions import ExitAppException from lightning_app.utilities.introspection import _is_init_context, _is_run_context +from lightning_app.utilities.packaging.cloud_compute import _maybe_create_cloud_compute, CloudCompute class LightningFlow: @@ -145,6 +146,8 @@ class LightningFlow: # Attach the backend to the flow and its children work. 
if self._backend: LightningFlow._attach_backend(value, self._backend) + for work in value.works(): + work._register_cloud_compute() elif isinstance(value, LightningWork): self._works.add(name) @@ -153,6 +156,7 @@ class LightningFlow: self._state.remove(name) if self._backend: self._backend._wrap_run_method(_LightningAppRef().get_current(), value) + value._register_cloud_compute() elif isinstance(value, (Dict, List)): value._backend = self._backend @@ -177,6 +181,9 @@ class LightningFlow: value.component_name = self.name self._state.add(name) + elif isinstance(value, CloudCompute): + self._state.add(name) + elif _is_json_serializable(value): self._state.add(name) @@ -320,6 +327,8 @@ class LightningFlow: for k, v in provided_state["vars"].items(): if isinstance(v, Dict): v = _maybe_create_drive(self.name, v) + if isinstance(v, Dict): + v = _maybe_create_cloud_compute(v) setattr(self, k, v) self._changes = provided_state["changes"] self._calls.update(provided_state["calls"]) diff --git a/src/lightning_app/core/work.py b/src/lightning_app/core/work.py index 6bc299a8b4..26395672e8 100644 --- a/src/lightning_app/core/work.py +++ b/src/lightning_app/core/work.py @@ -24,7 +24,12 @@ from lightning_app.utilities.exceptions import LightningWorkException from lightning_app.utilities.introspection import _is_init_context from lightning_app.utilities.network import find_free_network_port from lightning_app.utilities.packaging.build_config import BuildConfig -from lightning_app.utilities.packaging.cloud_compute import CloudCompute +from lightning_app.utilities.packaging.cloud_compute import ( + _CLOUD_COMPUTE_STORE, + _CloudComputeStore, + _maybe_create_cloud_compute, + CloudCompute, +) from lightning_app.utilities.proxies import LightningWorkSetAttrProxy, ProxyWorkRun, unwrap @@ -103,7 +108,7 @@ class LightningWork: " in the next version. Use `cache_calls` instead." ) self._cache_calls = run_once if run_once is not None else cache_calls - self._state = {"_host", "_port", "_url", "_future_url", "_internal_ip", "_restarting"} + self._state = {"_host", "_port", "_url", "_future_url", "_internal_ip", "_restarting", "_cloud_compute"} self._parallel = parallel self._host: str = host self._port: Optional[int] = port @@ -226,8 +231,14 @@ class LightningWork: return self._cloud_compute @cloud_compute.setter - def cloud_compute(self, cloud_compute) -> None: + def cloud_compute(self, cloud_compute: CloudCompute) -> None: """Returns the cloud compute used to select the cloud hardware.""" + # A new ID + current_id = self._cloud_compute.id + new_id = cloud_compute.id + if current_id != new_id: + compute_store: _CloudComputeStore = _CLOUD_COMPUTE_STORE[current_id] + compute_store.remove(self.name) self._cloud_compute = cloud_compute @property @@ -488,6 +499,8 @@ class LightningWork: for k, v in provided_state["vars"].items(): if isinstance(v, Dict): v = _maybe_create_drive(self.name, v) + if isinstance(v, Dict): + v = _maybe_create_cloud_compute(v) setattr(self, k, v) self._changes = provided_state["changes"] @@ -575,3 +588,10 @@ class LightningWork: f"The work `{self.__class__.__name__}` is missing the `run()` method. This is required. Implement it" " first and then call it in your Flow." 
) + + def _register_cloud_compute(self): + internal_id = self.cloud_compute.id + assert internal_id + if internal_id not in _CLOUD_COMPUTE_STORE: + _CLOUD_COMPUTE_STORE[internal_id] = _CloudComputeStore(id=internal_id, component_names=[]) + _CLOUD_COMPUTE_STORE[internal_id].add_component_name(self.name) diff --git a/src/lightning_app/runners/backends/mp_process.py b/src/lightning_app/runners/backends/mp_process.py index a1667422b2..9365e6ac1d 100644 --- a/src/lightning_app/runners/backends/mp_process.py +++ b/src/lightning_app/runners/backends/mp_process.py @@ -74,8 +74,8 @@ class MultiProcessingBackend(Backend): and work._url == "" and work._port ): - url = f"http://{work._host}:{work._port}" - if _check_service_url_is_ready(url): + url = work._future_url if work._future_url else f"http://{work._host}:{work._port}" + if _check_service_url_is_ready(url, metadata=f"Checking {work.name}"): work._url = url def stop_work(self, app, work: "lightning_app.LightningWork") -> None: diff --git a/src/lightning_app/runners/cloud.py b/src/lightning_app/runners/cloud.py index 38083d9507..8f3e4a0b50 100644 --- a/src/lightning_app/runners/cloud.py +++ b/src/lightning_app/runners/cloud.py @@ -40,7 +40,13 @@ from lightning_cloud.openapi import ( from lightning_cloud.openapi.rest import ApiException from lightning_app.core.app import LightningApp -from lightning_app.core.constants import CLOUD_UPLOAD_WARNING, DISABLE_DEPENDENCY_CACHE +from lightning_app.core.constants import ( + CLOUD_UPLOAD_WARNING, + DEFAULT_NUMBER_OF_EXPOSED_PORTS, + DISABLE_DEPENDENCY_CACHE, + ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER, + ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER, +) from lightning_app.runners.backends.cloud import CloudBackend from lightning_app.runners.runtime import Runtime from lightning_app.source_code import LocalSourceCodeDir @@ -108,6 +114,12 @@ class CloudRuntime(Runtime): ] v1_env_vars.extend(env_vars_from_secrets) + if ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER: + v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", value="1")) + + if ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER: + v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", value="1")) + work_reqs: List[V1Work] = [] for flow in self.app.flows: for work in flow.works(recurse=False): @@ -184,6 +196,7 @@ class CloudRuntime(Runtime): desired_state=V1LightningappInstanceState.RUNNING, env=v1_env_vars, ) + # if requirements file at the root of the repository is present, # we pass just the file name to the backend, so backend can find it in the relative path if requirements_file.is_file(): @@ -219,6 +232,22 @@ class CloudRuntime(Runtime): local_source=True, dependency_cache_key=app_spec.dependency_cache_key, ) + + if ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER: + network_configs: List[V1NetworkConfig] = [] + + initial_port = 8080 + 1 + len(frontend_specs) + for _ in range(DEFAULT_NUMBER_OF_EXPOSED_PORTS): + network_configs.append( + V1NetworkConfig( + name="w" + str(initial_port), + port=initial_port, + ) + ) + initial_port += 1 + + release_body.network_config = network_configs + if cluster_id is not None: self._ensure_cluster_project_binding(project.project_id, cluster_id) diff --git a/src/lightning_app/utilities/component.py b/src/lightning_app/utilities/component.py index f7fd0ff850..0e79854e18 100644 --- a/src/lightning_app/utilities/component.py +++ b/src/lightning_app/utilities/component.py @@ -7,6 +7,7 @@ from lightning_utilities.core.apply_func import apply_to_collection from 
lightning_app.utilities.app_helpers import is_overridden from lightning_app.utilities.enum import ComponentContext +from lightning_app.utilities.packaging.cloud_compute import CloudCompute from lightning_app.utilities.tree import breadth_first if TYPE_CHECKING: @@ -52,9 +53,13 @@ def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]: def sanitize_drive(drive: Drive) -> Dict: return drive.to_dict() + def sanitize_cloud_compute(cloud_compute: CloudCompute) -> Dict: + return cloud_compute.to_dict() + state = apply_to_collection(state, dtype=Path, function=sanitize_path) state = apply_to_collection(state, dtype=BasePayload, function=sanitize_payload) state = apply_to_collection(state, dtype=Drive, function=sanitize_drive) + state = apply_to_collection(state, dtype=CloudCompute, function=sanitize_cloud_compute) return state diff --git a/src/lightning_app/utilities/layout.py b/src/lightning_app/utilities/layout.py index bffc56e919..ca12ab8b7a 100644 --- a/src/lightning_app/utilities/layout.py +++ b/src/lightning_app/utilities/layout.py @@ -92,8 +92,10 @@ def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFl f"You configured an http link {url[:32]}... but it won't be accessible in the cloud." f" Consider replacing 'http' with 'https' in the link above." ) + elif isinstance(entry["content"], lightning_app.LightningFlow): entry["content"] = entry["content"].name + elif isinstance(entry["content"], lightning_app.LightningWork): if entry["content"].url and not entry["content"].url.startswith("/"): entry["content"] = entry["content"].url diff --git a/src/lightning_app/utilities/network.py b/src/lightning_app/utilities/network.py index 4ed4ea1591..4012c2513d 100644 --- a/src/lightning_app/utilities/network.py +++ b/src/lightning_app/utilities/network.py @@ -49,12 +49,12 @@ def _configure_session() -> Session: return http -def _check_service_url_is_ready(url: str, timeout: float = 5) -> bool: +def _check_service_url_is_ready(url: str, timeout: float = 5, metadata="") -> bool: try: response = requests.get(url, timeout=timeout) return response.status_code in (200, 404) except (ConnectionError, ConnectTimeout, ReadTimeout): - logger.debug(f"The url {url} is not ready.") + logger.debug(f"The url {url} is not ready. 
{metadata}") return False diff --git a/src/lightning_app/utilities/packaging/cloud_compute.py b/src/lightning_app/utilities/packaging/cloud_compute.py index 6527911855..5832d809c1 100644 --- a/src/lightning_app/utilities/packaging/cloud_compute.py +++ b/src/lightning_app/utilities/packaging/cloud_compute.py @@ -1,5 +1,48 @@ from dataclasses import asdict, dataclass -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union +from uuid import uuid4 + +from lightning_app.core.constants import ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER + +__CLOUD_COMPUTE_IDENTIFIER__ = "__cloud_compute__" + + +@dataclass +class _CloudComputeStore: + id: str + component_names: List[str] + + def add_component_name(self, new_component_name: str) -> None: + found_index = None + # When the work is being named by the flow, pop its previous names + for index, component_name in enumerate(self.component_names): + if new_component_name.endswith(component_name.replace("root.", "")): + found_index = index + + if found_index is not None: + self.component_names[found_index] = new_component_name + else: + if ( + len(self.component_names) == 1 + and not ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER + and self.id != "default" + ): + raise Exception( + f"A Cloud Compute can be assigned only to a single Work. Attached to {self.component_names[0]}" + ) + self.component_names.append(new_component_name) + + def remove(self, new_component_name: str) -> None: + found_index = None + for index, component_name in enumerate(self.component_names): + if new_component_name == component_name: + found_index = index + + if found_index is not None: + del self.component_names[found_index] + + +_CLOUD_COMPUTE_STORE = {} @dataclass @@ -43,6 +86,7 @@ class CloudCompute: wait_timeout: Optional[int] = None idle_timeout: Optional[int] = None shm_size: Optional[int] = 0 + _internal_id: Optional[str] = None def __post_init__(self): if self.clusters: @@ -52,9 +96,28 @@ class CloudCompute: self.name = self.name.lower() + # All `default` CloudCompute are identified in the same way. + if self._internal_id is None: + self._internal_id = "default" if self.name == "default" else uuid4().hex[:7] + def to_dict(self): - return {"__cloud_compute__": asdict(self)} + return {"type": __CLOUD_COMPUTE_IDENTIFIER__, **asdict(self)} @classmethod def from_dict(cls, d): - return cls(**d["__cloud_compute__"]) + assert d.pop("type") == __CLOUD_COMPUTE_IDENTIFIER__ + return cls(**d) + + @property + def id(self) -> Optional[str]: + return self._internal_id + + def is_default(self) -> bool: + return self.name == "default" + + +def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]: + if state and __CLOUD_COMPUTE_IDENTIFIER__ == state.get("type", None): + cloud_compute = CloudCompute.from_dict(state) + return cloud_compute + return state diff --git a/src/lightning_app/utilities/packaging/lightning_utils.py b/src/lightning_app/utilities/packaging/lightning_utils.py index f0e87f63e6..ed81b2bff1 100644 --- a/src/lightning_app/utilities/packaging/lightning_utils.py +++ b/src/lightning_app/utilities/packaging/lightning_utils.py @@ -110,7 +110,7 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable] download_frontend(_PROJECT_ROOT) _prepare_wheel(_PROJECT_ROOT) - logger.info("Packaged Lightning with your application.") + logger.info(f"Packaged Lightning with your application. 
Version: {version}") tar_name = _copy_tar(_PROJECT_ROOT, root) @@ -121,7 +121,9 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable] # building and copying launcher wheel if installed in editable mode launcher_project_path = get_dist_path_if_editable_install("lightning_launcher") if launcher_project_path: - logger.info("Packaged Lightning Launcher with your application.") + from lightning_launcher.__version__ import __version__ as launcher_version + + logger.info(f"Packaged Lightning Launcher with your application. Version: {launcher_version}") _prepare_wheel(launcher_project_path) tar_name = _copy_tar(launcher_project_path, root) tar_files.append(os.path.join(root, tar_name)) @@ -129,7 +131,9 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable] # building and copying lightning-cloud wheel if installed in editable mode lightning_cloud_project_path = get_dist_path_if_editable_install("lightning_cloud") if lightning_cloud_project_path: - logger.info("Packaged Lightning Cloud with your application.") + from lightning_cloud.__version__ import __version__ as cloud_version + + logger.info(f"Packaged Lightning Cloud with your application. Version: {cloud_version}") _prepare_wheel(lightning_cloud_project_path) tar_name = _copy_tar(lightning_cloud_project_path, root) tar_files.append(os.path.join(root, tar_name)) diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py index af0071322b..e8b887637e 100644 --- a/tests/tests_app/conftest.py +++ b/tests/tests_app/conftest.py @@ -11,6 +11,7 @@ from tests_app import _PROJECT_ROOT from lightning_app.storage.path import storage_root_dir from lightning_app.utilities.component import _set_context +from lightning_app.utilities.packaging import cloud_compute from lightning_app.utilities.packaging.app_config import _APP_CONFIG_FILENAME from lightning_app.utilities.state import AppState @@ -74,6 +75,7 @@ def clear_app_state_state_variables(): lightning_app.utilities.state._STATE = None lightning_app.utilities.state._LAST_STATE = None AppState._MY_AFFILIATION = () + cloud_compute._CLOUD_COMPUTE_STORE.clear() @pytest.fixture diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index 93c475700d..2f42643297 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -11,7 +11,7 @@ from deepdiff import Delta from pympler import asizeof from tests_app import _PROJECT_ROOT -from lightning_app import LightningApp, LightningFlow, LightningWork # F401 +from lightning_app import CloudCompute, LightningApp, LightningFlow, LightningWork # F401 from lightning_app.core.constants import ( FLOW_DURATION_SAMPLES, FLOW_DURATION_THRESHOLD, @@ -27,6 +27,7 @@ from lightning_app.testing.helpers import RunIf from lightning_app.testing.testing import LightningTestApp from lightning_app.utilities.app_helpers import affiliation from lightning_app.utilities.enum import AppStage, WorkStageStatus, WorkStopReasons +from lightning_app.utilities.packaging import cloud_compute from lightning_app.utilities.redis import check_if_redis_running from lightning_app.utilities.warnings import LightningFlowWarning @@ -255,7 +256,7 @@ def test_nested_component(runtime_cls): assert app.root.b.c.d.e.w_e.c == 1 -class WorkCC(LightningWork): +class WorkCCC(LightningWork): def run(self): pass @@ -263,7 +264,7 @@ class WorkCC(LightningWork): class CC(LightningFlow): def __init__(self): super().__init__() - self.work_cc = WorkCC() + 
self.work_cc = WorkCCC() def run(self): pass @@ -719,7 +720,7 @@ class WorkDD(LightningWork): self.counter += 1 -class FlowCC(LightningFlow): +class FlowCCTolerance(LightningFlow): def __init__(self): super().__init__() self.work = WorkDD() @@ -744,7 +745,7 @@ class FaultToleranceLightningTestApp(LightningTestApp): # TODO (tchaton) Resolve this test with Resumable App. @RunIf(skip_windows=True) def test_fault_tolerance_work(): - app = FaultToleranceLightningTestApp(FlowCC()) + app = FaultToleranceLightningTestApp(FlowCCTolerance()) MultiProcessRuntime(app, start_server=False).dispatch() assert app.root.work.counter == 2 @@ -952,8 +953,8 @@ class SizeFlow(LightningFlow): def test_state_size_constant_growth(): app = LightningApp(SizeFlow()) MultiProcessRuntime(app, start_server=False).dispatch() - assert app.root._state_sizes[0] <= 5904 - assert app.root._state_sizes[20] <= 23736 + assert app.root._state_sizes[0] <= 6952 + assert app.root._state_sizes[20] <= 24896 class FlowUpdated(LightningFlow): @@ -1041,3 +1042,70 @@ class TestLightningHasUpdatedApp(LightningApp): def test_lightning_app_has_updated(): app = TestLightningHasUpdatedApp(FlowPath()) MultiProcessRuntime(app, start_server=False).dispatch() + + +class WorkCC(LightningWork): + def run(self): + pass + + +class FlowCC(LightningFlow): + def __init__(self): + super().__init__() + self.cloud_compute = CloudCompute(name="gpu", _internal_id="a") + self.work_a = WorkCC(cloud_compute=self.cloud_compute) + self.work_b = WorkCC(cloud_compute=self.cloud_compute) + self.work_c = WorkCC() + assert self.work_a.cloud_compute._internal_id == self.work_b.cloud_compute._internal_id + + def run(self): + self.work_d = WorkCC() + + +class FlowWrapper(LightningFlow): + def __init__(self, flow): + super().__init__() + self.w = flow + + +def test_cloud_compute_binding(): + + cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = True + + assert cloud_compute._CLOUD_COMPUTE_STORE == {} + flow = FlowCC() + assert len(cloud_compute._CLOUD_COMPUTE_STORE) == 2 + assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.work_c"] + assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.work_a", "root.work_b"] + + wrapper = FlowWrapper(flow) + assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.w.work_c"] + assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.work_a", "root.w.work_b"] + + _ = FlowWrapper(wrapper) + assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.w.w.work_c"] + assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.w.work_a", "root.w.w.work_b"] + + assert "__cloud_compute__" == flow.state["vars"]["cloud_compute"]["type"] + assert "__cloud_compute__" == flow.work_a.state["vars"]["_cloud_compute"]["type"] + assert "__cloud_compute__" == flow.work_b.state["vars"]["_cloud_compute"]["type"] + assert "__cloud_compute__" == flow.work_c.state["vars"]["_cloud_compute"]["type"] + work_a_id = flow.work_a.state["vars"]["_cloud_compute"]["_internal_id"] + work_b_id = flow.work_b.state["vars"]["_cloud_compute"]["_internal_id"] + work_c_id = flow.work_c.state["vars"]["_cloud_compute"]["_internal_id"] + assert work_a_id == work_b_id + assert work_a_id != work_c_id + assert work_c_id == "default" + + flow.work_a.cloud_compute = CloudCompute(name="something_else") + assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.w.work_b"] + + flow.set_state(flow.state) + assert isinstance(flow.cloud_compute, 
CloudCompute) + assert isinstance(flow.work_a.cloud_compute, CloudCompute) + assert isinstance(flow.work_c.cloud_compute, CloudCompute) + + cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = False + + with pytest.raises(Exception, match="A Cloud Compute can be assigned only to a single Work"): + FlowCC() diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py index 489def49fc..591c05e136 100644 --- a/tests/tests_app/core/test_lightning_flow.py +++ b/tests/tests_app/core/test_lightning_flow.py @@ -325,6 +325,17 @@ def test_lightning_flow_and_work(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, }, "calls": {CacheCallsKeys.LATEST_CALL_HASH: None}, "changes": {}, @@ -339,6 +350,17 @@ def test_lightning_flow_and_work(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, }, "calls": {CacheCallsKeys.LATEST_CALL_HASH: None}, "changes": {}, @@ -369,6 +391,17 @@ def test_lightning_flow_and_work(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, }, "calls": {CacheCallsKeys.LATEST_CALL_HASH: None}, "changes": {}, @@ -383,6 +416,17 @@ def test_lightning_flow_and_work(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, }, "calls": { CacheCallsKeys.LATEST_CALL_HASH: None, diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py index 18a6d372bf..73f656a39c 100644 --- a/tests/tests_app/structures/test_structures.py +++ b/tests/tests_app/structures/test_structures.py @@ -45,6 +45,17 @@ def test_dict(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, } for k in ("a", "b", "c", "d") ) @@ -68,6 +79,17 @@ def test_dict(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, } for k in ("a", "b", "c", "d") ) @@ -91,6 +113,17 @@ def test_dict(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, } for k in ("a", "b", "c", "d") ) @@ -166,6 +199,17 
@@ def test_list(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, } for i in range(4) ) @@ -189,6 +233,17 @@ def test_list(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, } for i in range(4) ) @@ -207,6 +262,17 @@ def test_list(): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_cloud_compute": { + "type": "__cloud_compute__", + "name": "default", + "disk_size": 0, + "clusters": None, + "preemptible": False, + "wait_timeout": None, + "idle_timeout": None, + "shm_size": 0, + "_internal_id": "default", + }, } for i in range(4) ) diff --git a/tests/tests_app/utilities/packaging/test_lightning_utils.py b/tests/tests_app/utilities/packaging/test_lightning_utils.py index 8f30aa21dd..4527a6309b 100644 --- a/tests/tests_app/utilities/packaging/test_lightning_utils.py +++ b/tests/tests_app/utilities/packaging/test_lightning_utils.py @@ -3,7 +3,6 @@ from unittest import mock import pytest -from lightning.__version__ import version from lightning_app.testing.helpers import RunIf from lightning_app.utilities.packaging import lightning_utils from lightning_app.utilities.packaging.lightning_utils import ( @@ -16,8 +15,10 @@ def test_prepare_lightning_wheels_and_requirement(tmpdir): """This test ensures the lightning source gets packaged inside the lightning repo.""" cleanup_handle = _prepare_lightning_wheels_and_requirements(tmpdir) + from lightning.__version__ import version + tar_name = f"lightning-{version}.tar.gz" - assert sorted(os.listdir(tmpdir)) == [tar_name] + assert sorted(os.listdir(tmpdir))[0] == tar_name cleanup_handle() assert os.listdir(tmpdir) == [] diff --git a/tests/tests_app/utilities/test_load_app.py b/tests/tests_app/utilities/test_load_app.py index 09af874cbd..dd5d8b136d 100644 --- a/tests/tests_app/utilities/test_load_app.py +++ b/tests/tests_app/utilities/test_load_app.py @@ -49,7 +49,7 @@ def test_extract_metadata_from_component(): "docstring": "WorkA.", "local_build_config": {"__build_config__": ANY}, "cloud_build_config": {"__build_config__": ANY}, - "cloud_compute": {"__cloud_compute__": ANY}, + "cloud_compute": ANY, }, { "affiliation": ["root", "flow_a_2"], @@ -64,7 +64,7 @@ def test_extract_metadata_from_component(): "docstring": "WorkA.", "local_build_config": {"__build_config__": ANY}, "cloud_build_config": {"__build_config__": ANY}, - "cloud_compute": {"__cloud_compute__": ANY}, + "cloud_compute": ANY, }, { "affiliation": ["root", "flow_b"], @@ -79,6 +79,6 @@ def test_extract_metadata_from_component(): "docstring": "WorkB.", "local_build_config": {"__build_config__": ANY}, "cloud_build_config": {"__build_config__": ANY}, - "cloud_compute": {"__cloud_compute__": ANY}, + "cloud_compute": ANY, }, ]
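
For context outside the diff: a minimal sketch of how the `CloudCompute` ID and serialization introduced by this patch behave, assuming the `lightning_app` package from this branch is installed; it only uses the `CloudCompute` fields and the `_maybe_create_cloud_compute` helper added above, and is illustrative rather than part of the commit.

```python
# Sketch (assumes lightning_app from this branch). Shows the new `_internal_id`
# ("default" for the default machine, a short uuid otherwise) and the
# to_dict()/from_dict() round trip tagged with "__cloud_compute__".
from lightning_app.utilities.packaging.cloud_compute import (
    CloudCompute,
    _maybe_create_cloud_compute,
)

shared = CloudCompute(name="gpu")       # receives a random 7-char `_internal_id`
default = CloudCompute(name="default")  # `_internal_id` is always "default"

state = shared.to_dict()
assert state["type"] == "__cloud_compute__"

# A work's serialized `_cloud_compute` entry is rebuilt the same way when the
# state is replayed on the flow side:
restored = _maybe_create_cloud_compute(state)
assert isinstance(restored, CloudCompute)
assert restored.id == shared.id         # the ID survives serialization
assert default.id == "default"
```

Because the ID is stable across serialization, `_CloudComputeStore` can track which works share a machine: passing the same `CloudCompute` object to several works groups them under one ID, while attaching a non-default compute to more than one work raises unless `ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER` is set, as exercised in `test_cloud_compute_binding` above.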