[app] Add CloudCompute ID serializable within the flow and works state (#14819)

This commit is contained in:
thomas chaton 2022-10-04 21:46:44 +02:00 committed by GitHub
parent 53694eb93d
commit b936fd4380
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 409 additions and 29 deletions
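In short, a CloudCompute attached to a flow or a work is now stored in the component state as a plain, JSON-serializable dict tagged "__cloud_compute__", together with an `_internal_id` used to track which works share the same compute. A minimal sketch of the resulting behaviour (not part of the diff; class names are illustrative):

from lightning_app import CloudCompute, LightningFlow, LightningWork


class MyWork(LightningWork):
    def run(self):
        pass


class MyFlow(LightningFlow):
    def __init__(self):
        super().__init__()
        # The compute config (including its generated `_internal_id`) ends up in the work state.
        self.work = MyWork(cloud_compute=CloudCompute(name="cpu-small"))

    def run(self):
        pass


flow = MyFlow()
# Serialized form, roughly:
# {"type": "__cloud_compute__", "name": "cpu-small", ..., "_internal_id": "<7-char hex id>"}
print(flow.work.state["vars"]["_cloud_compute"])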

3
.gitignore vendored
View File

@@ -112,7 +112,8 @@ celerybeat-schedule
# dotenv
.env
-.env_stagging
+.env_staging
+.env_local
# virtualenv
.venv

View File

@@ -0,0 +1 @@
streamlit

View File

@@ -0,0 +1 @@
name: app_works_on_default_machine

View File

@@ -0,0 +1,53 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from uvicorn import run

from lightning import CloudCompute, LightningApp, LightningFlow, LightningWork


class Work(LightningWork):
    def __init__(self, **kwargs):
        super().__init__(parallel=True, **kwargs)

    def run(self):
        fastapi_service = FastAPI()
        fastapi_service.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        @fastapi_service.get("/")
        def get_root():
            return {"Hello World!"}

        run(fastapi_service, host=self.host, port=self.port)


class Flow(LightningFlow):
    def __init__(self):
        super().__init__()
        # In the cloud, all works defined without explicitly passing a CloudCompute object
        # run on the default machine.
        # This applies to `work_a`, `work_b`, and the dynamically created `work_d`.
        self.work_a = Work()
        self.work_b = Work()
        self.work_c = Work(cloud_compute=CloudCompute(name="cpu-small"))

    def run(self):
        if not hasattr(self, "work_d"):
            self.work_d = Work()
        for work in self.works():
            work.run()

    def configure_layout(self):
        return [{"name": w.name, "content": w} for w in self.works()]


app = LightningApp(Flow(), debug=True)

View File

@@ -0,0 +1 @@
streamlit

View File

@@ -1,4 +1,4 @@
-lightning-cloud==0.5.7
+lightning-cloud>=0.5.7
packaging
deepdiff>=5.7.0, <=5.8.1
starsessions>=1.2.1, <2.0 # strict

View File

@@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Add `--secret` option to CLI to allow binding Secrets to app environment variables when running in the cloud ([#14612](https://github.com/Lightning-AI/lightning/pull/14612))
- Add support for running the works without cloud compute in the default container ([#14819](https://github.com/Lightning-AI/lightning/pull/14819))

### Changed

-

View File

@@ -47,5 +47,13 @@ DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_DEBUG", "0")))
LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components"
LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "0")))
DEBUG: bool = lightning_cloud.env.DEBUG
# EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE
DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50"))
ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER = bool(
int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0"))
) # Note: This is disabled for the time being.
ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = bool(
int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", "0"))
) # This isn't used in the cloud yet.
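These flags are read from the environment once, when `lightning_app.core.constants` is imported. A small sketch (not part of the diff) of enabling the experimental behaviour for a local run:

import os

# Must be set before lightning_app.core.constants is imported,
# because the flags above are evaluated at import time.
os.environ["ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER"] = "1"
os.environ["DEFAULT_NUMBER_OF_EXPOSED_PORTS"] = "25"

from lightning_app.core import constants

assert constants.ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER is True
assert constants.DEFAULT_NUMBER_OF_EXPOSED_PORTS == 25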

View File

@@ -14,6 +14,7 @@ from lightning_app.utilities.app_helpers import _is_json_serializable, _Lightnin
from lightning_app.utilities.component import _sanitize_state
from lightning_app.utilities.exceptions import ExitAppException
from lightning_app.utilities.introspection import _is_init_context, _is_run_context
from lightning_app.utilities.packaging.cloud_compute import _maybe_create_cloud_compute, CloudCompute
class LightningFlow:
@@ -145,6 +146,8 @@ class LightningFlow:
# Attach the backend to the flow and its children work.
if self._backend:
LightningFlow._attach_backend(value, self._backend)
for work in value.works():
work._register_cloud_compute()
elif isinstance(value, LightningWork):
self._works.add(name)
@@ -153,6 +156,7 @@ class LightningFlow:
self._state.remove(name)
if self._backend:
self._backend._wrap_run_method(_LightningAppRef().get_current(), value)
value._register_cloud_compute()
elif isinstance(value, (Dict, List)):
value._backend = self._backend
@@ -177,6 +181,9 @@ class LightningFlow:
value.component_name = self.name
self._state.add(name)
elif isinstance(value, CloudCompute):
self._state.add(name)
elif _is_json_serializable(value):
self._state.add(name)
@@ -320,6 +327,8 @@ class LightningFlow:
for k, v in provided_state["vars"].items():
if isinstance(v, Dict):
v = _maybe_create_drive(self.name, v)
if isinstance(v, Dict):
v = _maybe_create_cloud_compute(v)
setattr(self, k, v)
self._changes = provided_state["changes"]
self._calls.update(provided_state["calls"])

View File

@@ -24,7 +24,12 @@ from lightning_app.utilities.exceptions import LightningWorkException
from lightning_app.utilities.introspection import _is_init_context
from lightning_app.utilities.network import find_free_network_port
from lightning_app.utilities.packaging.build_config import BuildConfig
-from lightning_app.utilities.packaging.cloud_compute import CloudCompute
+from lightning_app.utilities.packaging.cloud_compute import (
+    _CLOUD_COMPUTE_STORE,
+    _CloudComputeStore,
+    _maybe_create_cloud_compute,
+    CloudCompute,
+)
from lightning_app.utilities.proxies import LightningWorkSetAttrProxy, ProxyWorkRun, unwrap
@@ -103,7 +108,7 @@ class LightningWork:
" in the next version. Use `cache_calls` instead."
)
self._cache_calls = run_once if run_once is not None else cache_calls
-self._state = {"_host", "_port", "_url", "_future_url", "_internal_ip", "_restarting"}
+self._state = {"_host", "_port", "_url", "_future_url", "_internal_ip", "_restarting", "_cloud_compute"}
self._parallel = parallel
self._host: str = host
self._port: Optional[int] = port
@@ -226,8 +231,14 @@ class LightningWork:
return self._cloud_compute
@cloud_compute.setter
-def cloud_compute(self, cloud_compute) -> None:
+def cloud_compute(self, cloud_compute: CloudCompute) -> None:
"""Returns the cloud compute used to select the cloud hardware."""
# A new ID
current_id = self._cloud_compute.id
new_id = cloud_compute.id
if current_id != new_id:
compute_store: _CloudComputeStore = _CLOUD_COMPUTE_STORE[current_id]
compute_store.remove(self.name)
self._cloud_compute = cloud_compute
@property
@@ -488,6 +499,8 @@ class LightningWork:
for k, v in provided_state["vars"].items():
if isinstance(v, Dict):
v = _maybe_create_drive(self.name, v)
if isinstance(v, Dict):
v = _maybe_create_cloud_compute(v)
setattr(self, k, v)
self._changes = provided_state["changes"]
@@ -575,3 +588,10 @@ class LightningWork:
f"The work `{self.__class__.__name__}` is missing the `run()` method. This is required. Implement it"
" first and then call it in your Flow."
)
def _register_cloud_compute(self):
internal_id = self.cloud_compute.id
assert internal_id
if internal_id not in _CLOUD_COMPUTE_STORE:
_CLOUD_COMPUTE_STORE[internal_id] = _CloudComputeStore(id=internal_id, component_names=[])
_CLOUD_COMPUTE_STORE[internal_id].add_component_name(self.name)
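`_register_cloud_compute` populates the module-level `_CLOUD_COMPUTE_STORE`, which maps each CloudCompute `_internal_id` to the names of the works using it. A rough sketch of the bookkeeping (mirroring the new test further below; class names are illustrative):

from lightning_app import CloudCompute, LightningFlow, LightningWork
from lightning_app.utilities.packaging import cloud_compute


class WorkX(LightningWork):
    def run(self):
        pass


class Root(LightningFlow):
    def __init__(self):
        super().__init__()
        shared = CloudCompute(name="gpu", _internal_id="a")
        self.work_a = WorkX(cloud_compute=shared)  # registered under id "a"
        self.work_b = WorkX(cloud_compute=shared)  # same id, second component name
        self.work_c = WorkX()                      # registered under the "default" id

    def run(self):
        pass


# Sharing a non-default CloudCompute between works currently requires the
# experimental flag; otherwise `add_component_name` raises an Exception.
cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = True
Root()
print(cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names)        # ["root.work_a", "root.work_b"]
print(cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names)  # ["root.work_c"]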

View File

@@ -74,8 +74,8 @@ class MultiProcessingBackend(Backend):
and work._url == ""
and work._port
):
-url = f"http://{work._host}:{work._port}"
+url = work._future_url if work._future_url else f"http://{work._host}:{work._port}"
-if _check_service_url_is_ready(url):
+if _check_service_url_is_ready(url, metadata=f"Checking {work.name}"):
work._url = url
def stop_work(self, app, work: "lightning_app.LightningWork") -> None:

View File

@@ -40,7 +40,13 @@ from lightning_cloud.openapi import (
from lightning_cloud.openapi.rest import ApiException
from lightning_app.core.app import LightningApp
-from lightning_app.core.constants import CLOUD_UPLOAD_WARNING, DISABLE_DEPENDENCY_CACHE
+from lightning_app.core.constants import (
+    CLOUD_UPLOAD_WARNING,
+    DEFAULT_NUMBER_OF_EXPOSED_PORTS,
+    DISABLE_DEPENDENCY_CACHE,
+    ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER,
+    ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER,
+)
from lightning_app.runners.backends.cloud import CloudBackend
from lightning_app.runners.runtime import Runtime
from lightning_app.source_code import LocalSourceCodeDir
@@ -108,6 +114,12 @@ class CloudRuntime(Runtime):
]
v1_env_vars.extend(env_vars_from_secrets)
if ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER:
v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", value="1"))
if ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER:
v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", value="1"))
work_reqs: List[V1Work] = []
for flow in self.app.flows:
for work in flow.works(recurse=False):
@@ -184,6 +196,7 @@ class CloudRuntime(Runtime):
desired_state=V1LightningappInstanceState.RUNNING,
env=v1_env_vars,
)
# if requirements file at the root of the repository is present,
# we pass just the file name to the backend, so backend can find it in the relative path
if requirements_file.is_file():
@@ -219,6 +232,22 @@ class CloudRuntime(Runtime):
local_source=True,
dependency_cache_key=app_spec.dependency_cache_key,
)
if ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER:
network_configs: List[V1NetworkConfig] = []
initial_port = 8080 + 1 + len(frontend_specs)
for _ in range(DEFAULT_NUMBER_OF_EXPOSED_PORTS):
network_configs.append(
V1NetworkConfig(
name="w" + str(initial_port),
port=initial_port,
)
)
initial_port += 1
release_body.network_config = network_configs
if cluster_id is not None:
self._ensure_cluster_project_binding(project.project_id, cluster_id)
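For reference, a sketch of the port range this block exposes (numbers assumed, matching the defaults above): with, say, two frontends the generated configs would be named w8083 through w8132.

# Illustrative only: recompute the port names produced by the loop above.
frontend_specs = ["frontend_1", "frontend_2"]  # placeholder for the real specs
DEFAULT_NUMBER_OF_EXPOSED_PORTS = 50

initial_port = 8080 + 1 + len(frontend_specs)
names = ["w" + str(initial_port + i) for i in range(DEFAULT_NUMBER_OF_EXPOSED_PORTS)]
print(names[0], names[-1])  # w8083 w8132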

View File

@@ -7,6 +7,7 @@ from lightning_utilities.core.apply_func import apply_to_collection
from lightning_app.utilities.app_helpers import is_overridden
from lightning_app.utilities.enum import ComponentContext
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
from lightning_app.utilities.tree import breadth_first
if TYPE_CHECKING:
@@ -52,9 +53,13 @@ def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
def sanitize_drive(drive: Drive) -> Dict:
return drive.to_dict()
def sanitize_cloud_compute(cloud_compute: CloudCompute) -> Dict:
return cloud_compute.to_dict()
state = apply_to_collection(state, dtype=Path, function=sanitize_path)
state = apply_to_collection(state, dtype=BasePayload, function=sanitize_payload)
state = apply_to_collection(state, dtype=Drive, function=sanitize_drive)
state = apply_to_collection(state, dtype=CloudCompute, function=sanitize_cloud_compute)
return state

View File

@@ -92,8 +92,10 @@ def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFl
f"You configured an http link {url[:32]}... but it won't be accessible in the cloud."
f" Consider replacing 'http' with 'https' in the link above."
)
elif isinstance(entry["content"], lightning_app.LightningFlow):
entry["content"] = entry["content"].name
elif isinstance(entry["content"], lightning_app.LightningWork):
if entry["content"].url and not entry["content"].url.startswith("/"):
entry["content"] = entry["content"].url

View File

@@ -49,12 +49,12 @@ def _configure_session() -> Session:
return http
-def _check_service_url_is_ready(url: str, timeout: float = 5) -> bool:
+def _check_service_url_is_ready(url: str, timeout: float = 5, metadata="") -> bool:
try:
response = requests.get(url, timeout=timeout)
return response.status_code in (200, 404)
except (ConnectionError, ConnectTimeout, ReadTimeout):
-logger.debug(f"The url {url} is not ready.")
+logger.debug(f"The url {url} is not ready. {metadata}")
return False

View File

@@ -1,5 +1,48 @@
from dataclasses import asdict, dataclass
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
from uuid import uuid4
from lightning_app.core.constants import ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER
__CLOUD_COMPUTE_IDENTIFIER__ = "__cloud_compute__"
@dataclass
class _CloudComputeStore:
id: str
component_names: List[str]
def add_component_name(self, new_component_name: str) -> None:
found_index = None
# When the work is being named by the flow, pop its previous names
for index, component_name in enumerate(self.component_names):
if new_component_name.endswith(component_name.replace("root.", "")):
found_index = index
if found_index is not None:
self.component_names[found_index] = new_component_name
else:
if (
len(self.component_names) == 1
and not ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER
and self.id != "default"
):
raise Exception(
f"A Cloud Compute can be assigned only to a single Work. Attached to {self.component_names[0]}"
)
self.component_names.append(new_component_name)
def remove(self, new_component_name: str) -> None:
found_index = None
for index, component_name in enumerate(self.component_names):
if new_component_name == component_name:
found_index = index
if found_index is not None:
del self.component_names[found_index]
_CLOUD_COMPUTE_STORE = {}
@dataclass
@@ -43,6 +86,7 @@ class CloudCompute:
wait_timeout: Optional[int] = None
idle_timeout: Optional[int] = None
shm_size: Optional[int] = 0
_internal_id: Optional[str] = None
def __post_init__(self):
if self.clusters:
@@ -52,9 +96,28 @@ class CloudCompute:
self.name = self.name.lower()
# All `default` CloudCompute are identified in the same way.
if self._internal_id is None:
self._internal_id = "default" if self.name == "default" else uuid4().hex[:7]
def to_dict(self):
-return {"__cloud_compute__": asdict(self)}
+return {"type": __CLOUD_COMPUTE_IDENTIFIER__, **asdict(self)}
@classmethod
def from_dict(cls, d):
-return cls(**d["__cloud_compute__"])
+assert d.pop("type") == __CLOUD_COMPUTE_IDENTIFIER__
return cls(**d)
@property
def id(self) -> Optional[str]:
return self._internal_id
def is_default(self) -> bool:
return self.name == "default"
def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]:
if state and __CLOUD_COMPUTE_IDENTIFIER__ == state.get("type", None):
cloud_compute = CloudCompute.from_dict(state)
return cloud_compute
return state
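A quick round-trip through the new serialization format (a sketch, not part of the diff):

from lightning_app.utilities.packaging.cloud_compute import (
    _maybe_create_cloud_compute,
    CloudCompute,
)

cc = CloudCompute(name="gpu")
d = cc.to_dict()
# {"type": "__cloud_compute__", "name": "gpu", "disk_size": 0, ...,
#  "_internal_id": "<7-char hex id>"}
restored = _maybe_create_cloud_compute(d)
assert isinstance(restored, CloudCompute)
assert restored.id == cc.id

# Dicts without the identifier are passed through unchanged.
assert _maybe_create_cloud_compute({"foo": "bar"}) == {"foo": "bar"}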

View File

@@ -110,7 +110,7 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
download_frontend(_PROJECT_ROOT)
_prepare_wheel(_PROJECT_ROOT)
-logger.info("Packaged Lightning with your application.")
+logger.info(f"Packaged Lightning with your application. Version: {version}")
tar_name = _copy_tar(_PROJECT_ROOT, root)
@@ -121,7 +121,9 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
# building and copying launcher wheel if installed in editable mode
launcher_project_path = get_dist_path_if_editable_install("lightning_launcher")
if launcher_project_path:
-logger.info("Packaged Lightning Launcher with your application.")
+from lightning_launcher.__version__ import __version__ as launcher_version
+logger.info(f"Packaged Lightning Launcher with your application. Version: {launcher_version}")
_prepare_wheel(launcher_project_path)
tar_name = _copy_tar(launcher_project_path, root)
tar_files.append(os.path.join(root, tar_name))
@@ -129,7 +131,9 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
# building and copying lightning-cloud wheel if installed in editable mode
lightning_cloud_project_path = get_dist_path_if_editable_install("lightning_cloud")
if lightning_cloud_project_path:
-logger.info("Packaged Lightning Cloud with your application.")
+from lightning_cloud.__version__ import __version__ as cloud_version
+logger.info(f"Packaged Lightning Cloud with your application. Version: {cloud_version}")
_prepare_wheel(lightning_cloud_project_path)
tar_name = _copy_tar(lightning_cloud_project_path, root)
tar_files.append(os.path.join(root, tar_name))

View File

@@ -11,6 +11,7 @@ from tests_app import _PROJECT_ROOT
from lightning_app.storage.path import storage_root_dir
from lightning_app.utilities.component import _set_context
from lightning_app.utilities.packaging import cloud_compute
from lightning_app.utilities.packaging.app_config import _APP_CONFIG_FILENAME
from lightning_app.utilities.state import AppState
@@ -74,6 +75,7 @@ def clear_app_state_state_variables():
lightning_app.utilities.state._STATE = None
lightning_app.utilities.state._LAST_STATE = None
AppState._MY_AFFILIATION = ()
cloud_compute._CLOUD_COMPUTE_STORE.clear()
@pytest.fixture

View File

@@ -11,7 +11,7 @@ from deepdiff import Delta
from pympler import asizeof
from tests_app import _PROJECT_ROOT
-from lightning_app import LightningApp, LightningFlow, LightningWork # F401
+from lightning_app import CloudCompute, LightningApp, LightningFlow, LightningWork # F401
from lightning_app.core.constants import (
FLOW_DURATION_SAMPLES,
FLOW_DURATION_THRESHOLD,
@@ -27,6 +27,7 @@ from lightning_app.testing.helpers import RunIf
from lightning_app.testing.testing import LightningTestApp
from lightning_app.utilities.app_helpers import affiliation
from lightning_app.utilities.enum import AppStage, WorkStageStatus, WorkStopReasons
from lightning_app.utilities.packaging import cloud_compute
from lightning_app.utilities.redis import check_if_redis_running
from lightning_app.utilities.warnings import LightningFlowWarning
@@ -255,7 +256,7 @@ def test_nested_component(runtime_cls):
assert app.root.b.c.d.e.w_e.c == 1
-class WorkCC(LightningWork):
+class WorkCCC(LightningWork):
def run(self):
pass
@@ -263,7 +264,7 @@ class WorkCC(LightningWork):
class CC(LightningFlow):
def __init__(self):
super().__init__()
-self.work_cc = WorkCC()
+self.work_cc = WorkCCC()
def run(self):
pass
@@ -719,7 +720,7 @@ class WorkDD(LightningWork):
self.counter += 1
-class FlowCC(LightningFlow):
+class FlowCCTolerance(LightningFlow):
def __init__(self):
super().__init__()
self.work = WorkDD()
@@ -744,7 +745,7 @@ class FaultToleranceLightningTestApp(LightningTestApp):
# TODO (tchaton) Resolve this test with Resumable App.
@RunIf(skip_windows=True)
def test_fault_tolerance_work():
-app = FaultToleranceLightningTestApp(FlowCC())
+app = FaultToleranceLightningTestApp(FlowCCTolerance())
MultiProcessRuntime(app, start_server=False).dispatch()
assert app.root.work.counter == 2
@@ -952,8 +953,8 @@ class SizeFlow(LightningFlow):
def test_state_size_constant_growth():
app = LightningApp(SizeFlow())
MultiProcessRuntime(app, start_server=False).dispatch()
-assert app.root._state_sizes[0] <= 5904
+assert app.root._state_sizes[0] <= 6952
-assert app.root._state_sizes[20] <= 23736
+assert app.root._state_sizes[20] <= 24896
class FlowUpdated(LightningFlow):
@@ -1041,3 +1042,70 @@ class TestLightningHasUpdatedApp(LightningApp):
def test_lightning_app_has_updated():
app = TestLightningHasUpdatedApp(FlowPath())
MultiProcessRuntime(app, start_server=False).dispatch()
class WorkCC(LightningWork):
def run(self):
pass
class FlowCC(LightningFlow):
def __init__(self):
super().__init__()
self.cloud_compute = CloudCompute(name="gpu", _internal_id="a")
self.work_a = WorkCC(cloud_compute=self.cloud_compute)
self.work_b = WorkCC(cloud_compute=self.cloud_compute)
self.work_c = WorkCC()
assert self.work_a.cloud_compute._internal_id == self.work_b.cloud_compute._internal_id
def run(self):
self.work_d = WorkCC()
class FlowWrapper(LightningFlow):
def __init__(self, flow):
super().__init__()
self.w = flow
def test_cloud_compute_binding():
cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = True
assert cloud_compute._CLOUD_COMPUTE_STORE == {}
flow = FlowCC()
assert len(cloud_compute._CLOUD_COMPUTE_STORE) == 2
assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.work_c"]
assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.work_a", "root.work_b"]
wrapper = FlowWrapper(flow)
assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.w.work_c"]
assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.work_a", "root.w.work_b"]
_ = FlowWrapper(wrapper)
assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.w.w.work_c"]
assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.w.work_a", "root.w.w.work_b"]
assert "__cloud_compute__" == flow.state["vars"]["cloud_compute"]["type"]
assert "__cloud_compute__" == flow.work_a.state["vars"]["_cloud_compute"]["type"]
assert "__cloud_compute__" == flow.work_b.state["vars"]["_cloud_compute"]["type"]
assert "__cloud_compute__" == flow.work_c.state["vars"]["_cloud_compute"]["type"]
work_a_id = flow.work_a.state["vars"]["_cloud_compute"]["_internal_id"]
work_b_id = flow.work_b.state["vars"]["_cloud_compute"]["_internal_id"]
work_c_id = flow.work_c.state["vars"]["_cloud_compute"]["_internal_id"]
assert work_a_id == work_b_id
assert work_a_id != work_c_id
assert work_c_id == "default"
flow.work_a.cloud_compute = CloudCompute(name="something_else")
assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.w.work_b"]
flow.set_state(flow.state)
assert isinstance(flow.cloud_compute, CloudCompute)
assert isinstance(flow.work_a.cloud_compute, CloudCompute)
assert isinstance(flow.work_c.cloud_compute, CloudCompute)
cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = False
with pytest.raises(Exception, match="A Cloud Compute can be assigned only to a single Work"):
FlowCC()

View File

@@ -325,6 +325,17 @@ def test_lightning_flow_and_work():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
},
"calls": {CacheCallsKeys.LATEST_CALL_HASH: None},
"changes": {},
@@ -339,6 +350,17 @@ def test_lightning_flow_and_work():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
},
"calls": {CacheCallsKeys.LATEST_CALL_HASH: None},
"changes": {},
@@ -369,6 +391,17 @@ def test_lightning_flow_and_work():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
},
"calls": {CacheCallsKeys.LATEST_CALL_HASH: None},
"changes": {},
@@ -383,6 +416,17 @@ def test_lightning_flow_and_work():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
},
"calls": {
CacheCallsKeys.LATEST_CALL_HASH: None,

View File

@@ -45,6 +45,17 @@ def test_dict():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
}
for k in ("a", "b", "c", "d")
)
@@ -68,6 +79,17 @@ def test_dict():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
}
for k in ("a", "b", "c", "d")
)
@@ -91,6 +113,17 @@ def test_dict():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
}
for k in ("a", "b", "c", "d")
)
@@ -166,6 +199,17 @@ def test_list():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
}
for i in range(4)
)
@@ -189,6 +233,17 @@ def test_list():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
}
for i in range(4)
)
@@ -207,6 +262,17 @@ def test_list():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
"disk_size": 0,
"clusters": None,
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"shm_size": 0,
"_internal_id": "default",
},
}
for i in range(4)
)

View File

@@ -3,7 +3,6 @@ from unittest import mock
import pytest
-from lightning.__version__ import version
from lightning_app.testing.helpers import RunIf
from lightning_app.utilities.packaging import lightning_utils
from lightning_app.utilities.packaging.lightning_utils import (
@@ -16,8 +15,10 @@ def test_prepare_lightning_wheels_and_requirement(tmpdir):
"""This test ensures the lightning source gets packaged inside the lightning repo."""
cleanup_handle = _prepare_lightning_wheels_and_requirements(tmpdir)
from lightning.__version__ import version
tar_name = f"lightning-{version}.tar.gz"
-assert sorted(os.listdir(tmpdir)) == [tar_name]
+assert sorted(os.listdir(tmpdir))[0] == tar_name
cleanup_handle()
assert os.listdir(tmpdir) == []

View File

@@ -49,7 +49,7 @@ def test_extract_metadata_from_component():
"docstring": "WorkA.",
"local_build_config": {"__build_config__": ANY},
"cloud_build_config": {"__build_config__": ANY},
-"cloud_compute": {"__cloud_compute__": ANY},
+"cloud_compute": ANY,
},
{
"affiliation": ["root", "flow_a_2"],
@@ -64,7 +64,7 @@ def test_extract_metadata_from_component():
"docstring": "WorkA.",
"local_build_config": {"__build_config__": ANY},
"cloud_build_config": {"__build_config__": ANY},
-"cloud_compute": {"__cloud_compute__": ANY},
+"cloud_compute": ANY,
},
{
"affiliation": ["root", "flow_b"],
@@ -79,6 +79,6 @@ def test_extract_metadata_from_component():
"docstring": "WorkB.",
"local_build_config": {"__build_config__": ANY},
"cloud_build_config": {"__build_config__": ANY},
-"cloud_compute": {"__cloud_compute__": ANY},
+"cloud_compute": ANY,
},
]