[app] Make the CloudCompute ID serializable within the flow and works state (#14819)
Parent: 53694eb93d
Commit: b936fd4380
@@ -112,7 +112,8 @@ celerybeat-schedule
 # dotenv
 .env
 .env_stagging
+.env_staging
 .env_local
 
 # virtualenv
 .venv
@@ -0,0 +1 @@
+streamlit
@@ -0,0 +1 @@
+name: app_works_on_default_machine
@@ -0,0 +1,53 @@
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from uvicorn import run
+
+from lightning import CloudCompute, LightningApp, LightningFlow, LightningWork
+
+
+class Work(LightningWork):
+    def __init__(self, **kwargs):
+        super().__init__(parallel=True, **kwargs)
+
+    def run(self):
+        fastapi_service = FastAPI()
+
+        fastapi_service.add_middleware(
+            CORSMiddleware,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )
+
+        @fastapi_service.get("/")
+        def get_root():
+            return {"Hello World!"}
+
+        run(fastapi_service, host=self.host, port=self.port)
+
+
+class Flow(LightningFlow):
+    def __init__(self):
+        super().__init__()
+        # In the cloud, all works defined without explicitly passing a CloudCompute object
+        # run on the default machine.
+        # This applies to `work_a`, `work_b`, and the dynamically created `work_d`.
+
+        self.work_a = Work()
+        self.work_b = Work()
+
+        self.work_c = Work(cloud_compute=CloudCompute(name="cpu-small"))
+
+    def run(self):
+        if not hasattr(self, "work_d"):
+            self.work_d = Work()
+
+        for work in self.works():
+            work.run()
+
+    def configure_layout(self):
+        return [{"name": w.name, "content": w} for w in self.works()]
+
+
+app = LightningApp(Flow(), debug=True)
@@ -0,0 +1 @@
+streamlit
@@ -1,4 +1,4 @@
-lightning-cloud==0.5.7
+lightning-cloud>=0.5.7
 packaging
 deepdiff>=5.7.0, <=5.8.1
 starsessions>=1.2.1, <2.0 # strict
@@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Add `--secret` option to CLI to allow binding Secrets to app environment variables when running in the cloud ([#14612](https://github.com/Lightning-AI/lightning/pull/14612))
 
+
+- Add support for running works without a cloud compute in the default container ([#14819](https://github.com/Lightning-AI/lightning/pull/14819))
 
 ### Changed
 
 -
@@ -47,5 +47,13 @@ DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_DEBUG", "0")))
 LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components"
 LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
 ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "0")))
 
 DEBUG: bool = lightning_cloud.env.DEBUG
+
+# EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE
+DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50"))
+ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER = bool(
+    int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0"))
+)  # Note: This is disabled for the time being.
+ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = bool(
+    int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", "0"))
+)  # This isn't used in the cloud yet.
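These flags are read once, at import time, from environment variables. A minimal sketch of toggling the experimental behavior, assuming the module path used elsewhere in this commit and that the variable is set before lightning_app is first imported:

import os

# Assumption: the flag must be exported before lightning_app.core.constants is imported,
# because the module evaluates os.getenv(...) at import time (see the hunk above).
os.environ["ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER"] = "1"

from lightning_app.core import constants

assert constants.ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER is True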
@@ -14,6 +14,7 @@ from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef
 from lightning_app.utilities.component import _sanitize_state
 from lightning_app.utilities.exceptions import ExitAppException
 from lightning_app.utilities.introspection import _is_init_context, _is_run_context
+from lightning_app.utilities.packaging.cloud_compute import _maybe_create_cloud_compute, CloudCompute
 
 
 class LightningFlow:
@@ -145,6 +146,8 @@ class LightningFlow:
             # Attach the backend to the flow and its children work.
             if self._backend:
                 LightningFlow._attach_backend(value, self._backend)
+            for work in value.works():
+                work._register_cloud_compute()
 
         elif isinstance(value, LightningWork):
             self._works.add(name)
@@ -153,6 +156,7 @@ class LightningFlow:
                 self._state.remove(name)
             if self._backend:
                 self._backend._wrap_run_method(_LightningAppRef().get_current(), value)
+            value._register_cloud_compute()
 
         elif isinstance(value, (Dict, List)):
             value._backend = self._backend
@@ -177,6 +181,9 @@ class LightningFlow:
             value.component_name = self.name
             self._state.add(name)
 
+        elif isinstance(value, CloudCompute):
+            self._state.add(name)
+
         elif _is_json_serializable(value):
             self._state.add(name)
@@ -320,6 +327,8 @@ class LightningFlow:
         for k, v in provided_state["vars"].items():
             if isinstance(v, Dict):
                 v = _maybe_create_drive(self.name, v)
+            if isinstance(v, Dict):
+                v = _maybe_create_cloud_compute(v)
             setattr(self, k, v)
         self._changes = provided_state["changes"]
         self._calls.update(provided_state["calls"])
@@ -24,7 +24,12 @@ from lightning_app.utilities.exceptions import LightningWorkException
 from lightning_app.utilities.introspection import _is_init_context
 from lightning_app.utilities.network import find_free_network_port
 from lightning_app.utilities.packaging.build_config import BuildConfig
-from lightning_app.utilities.packaging.cloud_compute import CloudCompute
+from lightning_app.utilities.packaging.cloud_compute import (
+    _CLOUD_COMPUTE_STORE,
+    _CloudComputeStore,
+    _maybe_create_cloud_compute,
+    CloudCompute,
+)
 from lightning_app.utilities.proxies import LightningWorkSetAttrProxy, ProxyWorkRun, unwrap
 
 
@@ -103,7 +108,7 @@ class LightningWork:
                 " in the next version. Use `cache_calls` instead."
             )
         self._cache_calls = run_once if run_once is not None else cache_calls
-        self._state = {"_host", "_port", "_url", "_future_url", "_internal_ip", "_restarting"}
+        self._state = {"_host", "_port", "_url", "_future_url", "_internal_ip", "_restarting", "_cloud_compute"}
         self._parallel = parallel
         self._host: str = host
         self._port: Optional[int] = port
@@ -226,8 +231,14 @@ class LightningWork:
         return self._cloud_compute
 
     @cloud_compute.setter
-    def cloud_compute(self, cloud_compute) -> None:
+    def cloud_compute(self, cloud_compute: CloudCompute) -> None:
         """Returns the cloud compute used to select the cloud hardware."""
+        # A new CloudCompute id was assigned: de-register this work from the previous compute's store.
+        current_id = self._cloud_compute.id
+        new_id = cloud_compute.id
+        if current_id != new_id:
+            compute_store: _CloudComputeStore = _CLOUD_COMPUTE_STORE[current_id]
+            compute_store.remove(self.name)
         self._cloud_compute = cloud_compute
 
     @property
@@ -488,6 +499,8 @@ class LightningWork:
         for k, v in provided_state["vars"].items():
             if isinstance(v, Dict):
                 v = _maybe_create_drive(self.name, v)
+            if isinstance(v, Dict):
+                v = _maybe_create_cloud_compute(v)
             setattr(self, k, v)
 
         self._changes = provided_state["changes"]
@@ -575,3 +588,10 @@ class LightningWork:
             f"The work `{self.__class__.__name__}` is missing the `run()` method. This is required. Implement it"
             " first and then call it in your Flow."
         )
+
+    def _register_cloud_compute(self):
+        internal_id = self.cloud_compute.id
+        assert internal_id
+        if internal_id not in _CLOUD_COMPUTE_STORE:
+            _CLOUD_COMPUTE_STORE[internal_id] = _CloudComputeStore(id=internal_id, component_names=[])
+        _CLOUD_COMPUTE_STORE[internal_id].add_component_name(self.name)
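Taken together with the setter above, the bookkeeping works roughly like this: each work registers its compute id in the module-level `_CLOUD_COMPUTE_STORE`, and reassigning `cloud_compute` moves the work between buckets. A minimal sketch, assuming the in-process store introduced by this commit (the class names `Worker`/`Root` and the id "shared" are illustrative only):

from lightning_app import CloudCompute, LightningFlow, LightningWork
from lightning_app.utilities.packaging import cloud_compute


class Worker(LightningWork):
    def run(self):
        pass


class Root(LightningFlow):
    def __init__(self):
        super().__init__()
        # work_a gets an explicit, non-default compute identified by `_internal_id`;
        # work_b falls back to the shared "default" compute.
        self.work_a = Worker(cloud_compute=CloudCompute(name="gpu", _internal_id="shared"))
        self.work_b = Worker()

    def run(self):
        pass


# Inspecting the module-level store after the flow is built (mirrors the new test below).
flow = Root()
print(cloud_compute._CLOUD_COMPUTE_STORE["shared"].component_names)   # ["root.work_a"]
print(cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names)  # ["root.work_b"]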
@@ -74,8 +74,8 @@ class MultiProcessingBackend(Backend):
                 and work._url == ""
                 and work._port
             ):
-                url = f"http://{work._host}:{work._port}"
-                if _check_service_url_is_ready(url):
+                url = work._future_url if work._future_url else f"http://{work._host}:{work._port}"
+                if _check_service_url_is_ready(url, metadata=f"Checking {work.name}"):
                     work._url = url
 
     def stop_work(self, app, work: "lightning_app.LightningWork") -> None:
@@ -40,7 +40,13 @@ from lightning_cloud.openapi import (
 from lightning_cloud.openapi.rest import ApiException
 
 from lightning_app.core.app import LightningApp
-from lightning_app.core.constants import CLOUD_UPLOAD_WARNING, DISABLE_DEPENDENCY_CACHE
+from lightning_app.core.constants import (
+    CLOUD_UPLOAD_WARNING,
+    DEFAULT_NUMBER_OF_EXPOSED_PORTS,
+    DISABLE_DEPENDENCY_CACHE,
+    ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER,
+    ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER,
+)
 from lightning_app.runners.backends.cloud import CloudBackend
 from lightning_app.runners.runtime import Runtime
 from lightning_app.source_code import LocalSourceCodeDir
@@ -108,6 +114,12 @@ class CloudRuntime(Runtime):
             ]
             v1_env_vars.extend(env_vars_from_secrets)
+
+            if ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER:
+                v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", value="1"))
+
+            if ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER:
+                v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", value="1"))
 
             work_reqs: List[V1Work] = []
             for flow in self.app.flows:
                 for work in flow.works(recurse=False):
@@ -184,6 +196,7 @@ class CloudRuntime(Runtime):
                 desired_state=V1LightningappInstanceState.RUNNING,
                 env=v1_env_vars,
             )
+
             # if requirements file at the root of the repository is present,
             # we pass just the file name to the backend, so backend can find it in the relative path
             if requirements_file.is_file():
@@ -219,6 +232,22 @@ class CloudRuntime(Runtime):
                 local_source=True,
                 dependency_cache_key=app_spec.dependency_cache_key,
             )
+
+            if ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER:
+                network_configs: List[V1NetworkConfig] = []
+
+                initial_port = 8080 + 1 + len(frontend_specs)
+                for _ in range(DEFAULT_NUMBER_OF_EXPOSED_PORTS):
+                    network_configs.append(
+                        V1NetworkConfig(
+                            name="w" + str(initial_port),
+                            port=initial_port,
+                        )
+                    )
+                    initial_port += 1
+
+                release_body.network_config = network_configs
+
             if cluster_id is not None:
                 self._ensure_cluster_project_binding(project.project_id, cluster_id)
 
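A quick sanity check of the port layout implied by that loop, under the assumption of two frontend specs and the default of 50 exposed ports (plain arithmetic, no Lightning imports needed):

# Hypothetical inputs: two frontends and the default number of exposed ports.
frontend_specs = ["frontend_a", "frontend_b"]
DEFAULT_NUMBER_OF_EXPOSED_PORTS = 50

# Work ports start just above 8080, skipping one slot per frontend.
initial_port = 8080 + 1 + len(frontend_specs)
ports = [initial_port + i for i in range(DEFAULT_NUMBER_OF_EXPOSED_PORTS)]
names = ["w" + str(p) for p in ports]

assert ports[0] == 8083 and ports[-1] == 8132
assert names[0] == "w8083"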
@@ -7,6 +7,7 @@ from lightning_utilities.core.apply_func import apply_to_collection
 
 from lightning_app.utilities.app_helpers import is_overridden
 from lightning_app.utilities.enum import ComponentContext
+from lightning_app.utilities.packaging.cloud_compute import CloudCompute
 from lightning_app.utilities.tree import breadth_first
 
 if TYPE_CHECKING:
@@ -52,9 +53,13 @@ def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
     def sanitize_drive(drive: Drive) -> Dict:
         return drive.to_dict()
 
+    def sanitize_cloud_compute(cloud_compute: CloudCompute) -> Dict:
+        return cloud_compute.to_dict()
+
     state = apply_to_collection(state, dtype=Path, function=sanitize_path)
     state = apply_to_collection(state, dtype=BasePayload, function=sanitize_payload)
     state = apply_to_collection(state, dtype=Drive, function=sanitize_drive)
+    state = apply_to_collection(state, dtype=CloudCompute, function=sanitize_cloud_compute)
    return state
 
 
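A rough sketch of the effect, assuming the `to_dict` layout introduced further down in this commit; the hand-built state fragment below is hypothetical (real work state carries many more keys):

from lightning_app import CloudCompute
from lightning_app.utilities.component import _sanitize_state

state = {"vars": {"_cloud_compute": CloudCompute(name="cpu-small")}}
sanitized = _sanitize_state(state)

# The CloudCompute instance is replaced by its dict form, keeping the state JSON-serializable.
assert sanitized["vars"]["_cloud_compute"]["type"] == "__cloud_compute__"
assert sanitized["vars"]["_cloud_compute"]["name"] == "cpu-small"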
@@ -92,8 +92,10 @@ def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFlow"
                     f"You configured an http link {url[:32]}... but it won't be accessible in the cloud."
                     f" Consider replacing 'http' with 'https' in the link above."
                 )
 
+        elif isinstance(entry["content"], lightning_app.LightningFlow):
+            entry["content"] = entry["content"].name
 
         elif isinstance(entry["content"], lightning_app.LightningWork):
            if entry["content"].url and not entry["content"].url.startswith("/"):
                entry["content"] = entry["content"].url
@@ -49,12 +49,12 @@ def _configure_session() -> Session:
     return http
 
 
-def _check_service_url_is_ready(url: str, timeout: float = 5) -> bool:
+def _check_service_url_is_ready(url: str, timeout: float = 5, metadata="") -> bool:
     try:
         response = requests.get(url, timeout=timeout)
         return response.status_code in (200, 404)
     except (ConnectionError, ConnectTimeout, ReadTimeout):
-        logger.debug(f"The url {url} is not ready.")
+        logger.debug(f"The url {url} is not ready. {metadata}")
         return False
 
 
@@ -1,5 +1,48 @@
 from dataclasses import asdict, dataclass
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
+from uuid import uuid4
+
+from lightning_app.core.constants import ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER
+
+__CLOUD_COMPUTE_IDENTIFIER__ = "__cloud_compute__"
+
+
+@dataclass
+class _CloudComputeStore:
+    id: str
+    component_names: List[str]
+
+    def add_component_name(self, new_component_name: str) -> None:
+        found_index = None
+        # When the work is being named by the flow, pop its previous names
+        for index, component_name in enumerate(self.component_names):
+            if new_component_name.endswith(component_name.replace("root.", "")):
+                found_index = index
+
+        if found_index is not None:
+            self.component_names[found_index] = new_component_name
+        else:
+            if (
+                len(self.component_names) == 1
+                and not ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER
+                and self.id != "default"
+            ):
+                raise Exception(
+                    f"A Cloud Compute can be assigned only to a single Work. Attached to {self.component_names[0]}"
+                )
+            self.component_names.append(new_component_name)
+
+    def remove(self, new_component_name: str) -> None:
+        found_index = None
+        for index, component_name in enumerate(self.component_names):
+            if new_component_name == component_name:
+                found_index = index
+
+        if found_index is not None:
+            del self.component_names[found_index]
+
+
+_CLOUD_COMPUTE_STORE = {}
+
+
 @dataclass
@@ -43,6 +86,7 @@ class CloudCompute:
     wait_timeout: Optional[int] = None
     idle_timeout: Optional[int] = None
     shm_size: Optional[int] = 0
+    _internal_id: Optional[str] = None
 
     def __post_init__(self):
         if self.clusters:
@@ -52,9 +96,28 @@ class CloudCompute:
 
         self.name = self.name.lower()
 
+        # All `default` CloudCompute are identified in the same way.
+        if self._internal_id is None:
+            self._internal_id = "default" if self.name == "default" else uuid4().hex[:7]
+
     def to_dict(self):
-        return {"__cloud_compute__": asdict(self)}
+        return {"type": __CLOUD_COMPUTE_IDENTIFIER__, **asdict(self)}
 
     @classmethod
     def from_dict(cls, d):
-        return cls(**d["__cloud_compute__"])
+        assert d.pop("type") == __CLOUD_COMPUTE_IDENTIFIER__
+        return cls(**d)
 
+    @property
+    def id(self) -> Optional[str]:
+        return self._internal_id
+
+    def is_default(self) -> bool:
+        return self.name == "default"
+
+
+def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]:
+    if state and __CLOUD_COMPUTE_IDENTIFIER__ == state.get("type", None):
+        cloud_compute = CloudCompute.from_dict(state)
+        return cloud_compute
+    return state
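For orientation, a small sketch of how the new serialization round-trips, assuming the `CloudCompute` class from this file and the dict layout of `to_dict` above:

from lightning_app import CloudCompute

compute = CloudCompute(name="gpu")
d = compute.to_dict()
# The serialized form carries the identifier tag plus every dataclass field,
# including the generated `_internal_id`.
assert d["type"] == "__cloud_compute__"
assert d["name"] == "gpu"

restored = CloudCompute.from_dict(dict(d))  # from_dict pops "type", so pass a copy
assert restored.id == compute.id

# Every "default" compute shares the same id, so works created without an explicit
# CloudCompute all land in the same default bucket.
assert CloudCompute(name="default").id == "default"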
@@ -110,7 +110,7 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
         download_frontend(_PROJECT_ROOT)
         _prepare_wheel(_PROJECT_ROOT)
 
-        logger.info("Packaged Lightning with your application.")
+        logger.info(f"Packaged Lightning with your application. Version: {version}")
 
         tar_name = _copy_tar(_PROJECT_ROOT, root)
 
@@ -121,7 +121,9 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
     # building and copying launcher wheel if installed in editable mode
     launcher_project_path = get_dist_path_if_editable_install("lightning_launcher")
     if launcher_project_path:
-        logger.info("Packaged Lightning Launcher with your application.")
+        from lightning_launcher.__version__ import __version__ as launcher_version
+
+        logger.info(f"Packaged Lightning Launcher with your application. Version: {launcher_version}")
         _prepare_wheel(launcher_project_path)
         tar_name = _copy_tar(launcher_project_path, root)
         tar_files.append(os.path.join(root, tar_name))
@@ -129,7 +131,9 @@ def _prepare_lightning_wheels_and_requirements(root: Path) -> Optional[Callable]
     # building and copying lightning-cloud wheel if installed in editable mode
     lightning_cloud_project_path = get_dist_path_if_editable_install("lightning_cloud")
     if lightning_cloud_project_path:
-        logger.info("Packaged Lightning Cloud with your application.")
+        from lightning_cloud.__version__ import __version__ as cloud_version
+
+        logger.info(f"Packaged Lightning Cloud with your application. Version: {cloud_version}")
         _prepare_wheel(lightning_cloud_project_path)
         tar_name = _copy_tar(lightning_cloud_project_path, root)
         tar_files.append(os.path.join(root, tar_name))
@@ -11,6 +11,7 @@ from tests_app import _PROJECT_ROOT
 
 from lightning_app.storage.path import storage_root_dir
 from lightning_app.utilities.component import _set_context
+from lightning_app.utilities.packaging import cloud_compute
 from lightning_app.utilities.packaging.app_config import _APP_CONFIG_FILENAME
 from lightning_app.utilities.state import AppState
 
@@ -74,6 +75,7 @@ def clear_app_state_state_variables():
     lightning_app.utilities.state._STATE = None
     lightning_app.utilities.state._LAST_STATE = None
     AppState._MY_AFFILIATION = ()
+    cloud_compute._CLOUD_COMPUTE_STORE.clear()
 
 
 @pytest.fixture
@@ -11,7 +11,7 @@ from deepdiff import Delta
 from pympler import asizeof
 from tests_app import _PROJECT_ROOT
 
-from lightning_app import LightningApp, LightningFlow, LightningWork  # F401
+from lightning_app import CloudCompute, LightningApp, LightningFlow, LightningWork  # F401
 from lightning_app.core.constants import (
     FLOW_DURATION_SAMPLES,
     FLOW_DURATION_THRESHOLD,
@@ -27,6 +27,7 @@ from lightning_app.testing.helpers import RunIf
 from lightning_app.testing.testing import LightningTestApp
 from lightning_app.utilities.app_helpers import affiliation
 from lightning_app.utilities.enum import AppStage, WorkStageStatus, WorkStopReasons
+from lightning_app.utilities.packaging import cloud_compute
 from lightning_app.utilities.redis import check_if_redis_running
 from lightning_app.utilities.warnings import LightningFlowWarning
 
@@ -255,7 +256,7 @@ def test_nested_component(runtime_cls):
     assert app.root.b.c.d.e.w_e.c == 1
 
 
-class WorkCC(LightningWork):
+class WorkCCC(LightningWork):
     def run(self):
         pass
 
@@ -263,7 +264,7 @@ class WorkCC(LightningWork):
 class CC(LightningFlow):
     def __init__(self):
         super().__init__()
-        self.work_cc = WorkCC()
+        self.work_cc = WorkCCC()
 
     def run(self):
         pass
@@ -719,7 +720,7 @@ class WorkDD(LightningWork):
         self.counter += 1
 
 
-class FlowCC(LightningFlow):
+class FlowCCTolerance(LightningFlow):
     def __init__(self):
         super().__init__()
         self.work = WorkDD()
@@ -744,7 +745,7 @@ class FaultToleranceLightningTestApp(LightningTestApp):
 # TODO (tchaton) Resolve this test with Resumable App.
 @RunIf(skip_windows=True)
 def test_fault_tolerance_work():
-    app = FaultToleranceLightningTestApp(FlowCC())
+    app = FaultToleranceLightningTestApp(FlowCCTolerance())
     MultiProcessRuntime(app, start_server=False).dispatch()
     assert app.root.work.counter == 2
 
@@ -952,8 +953,8 @@ class SizeFlow(LightningFlow):
 def test_state_size_constant_growth():
     app = LightningApp(SizeFlow())
     MultiProcessRuntime(app, start_server=False).dispatch()
-    assert app.root._state_sizes[0] <= 5904
-    assert app.root._state_sizes[20] <= 23736
+    assert app.root._state_sizes[0] <= 6952
+    assert app.root._state_sizes[20] <= 24896
 
 
 class FlowUpdated(LightningFlow):
@@ -1041,3 +1042,70 @@ class TestLightningHasUpdatedApp(LightningApp):
 def test_lightning_app_has_updated():
     app = TestLightningHasUpdatedApp(FlowPath())
     MultiProcessRuntime(app, start_server=False).dispatch()
+
+
+class WorkCC(LightningWork):
+    def run(self):
+        pass
+
+
+class FlowCC(LightningFlow):
+    def __init__(self):
+        super().__init__()
+        self.cloud_compute = CloudCompute(name="gpu", _internal_id="a")
+        self.work_a = WorkCC(cloud_compute=self.cloud_compute)
+        self.work_b = WorkCC(cloud_compute=self.cloud_compute)
+        self.work_c = WorkCC()
+        assert self.work_a.cloud_compute._internal_id == self.work_b.cloud_compute._internal_id
+
+    def run(self):
+        self.work_d = WorkCC()
+
+
+class FlowWrapper(LightningFlow):
+    def __init__(self, flow):
+        super().__init__()
+        self.w = flow
+
+
+def test_cloud_compute_binding():
+
+    cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = True
+
+    assert cloud_compute._CLOUD_COMPUTE_STORE == {}
+    flow = FlowCC()
+    assert len(cloud_compute._CLOUD_COMPUTE_STORE) == 2
+    assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.work_c"]
+    assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.work_a", "root.work_b"]
+
+    wrapper = FlowWrapper(flow)
+    assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.w.work_c"]
+    assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.work_a", "root.w.work_b"]
+
+    _ = FlowWrapper(wrapper)
+    assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.w.w.work_c"]
+    assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.w.work_a", "root.w.w.work_b"]
+
+    assert "__cloud_compute__" == flow.state["vars"]["cloud_compute"]["type"]
+    assert "__cloud_compute__" == flow.work_a.state["vars"]["_cloud_compute"]["type"]
+    assert "__cloud_compute__" == flow.work_b.state["vars"]["_cloud_compute"]["type"]
+    assert "__cloud_compute__" == flow.work_c.state["vars"]["_cloud_compute"]["type"]
+    work_a_id = flow.work_a.state["vars"]["_cloud_compute"]["_internal_id"]
+    work_b_id = flow.work_b.state["vars"]["_cloud_compute"]["_internal_id"]
+    work_c_id = flow.work_c.state["vars"]["_cloud_compute"]["_internal_id"]
+    assert work_a_id == work_b_id
+    assert work_a_id != work_c_id
+    assert work_c_id == "default"
+
+    flow.work_a.cloud_compute = CloudCompute(name="something_else")
+    assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.w.work_b"]
+
+    flow.set_state(flow.state)
+    assert isinstance(flow.cloud_compute, CloudCompute)
+    assert isinstance(flow.work_a.cloud_compute, CloudCompute)
+    assert isinstance(flow.work_c.cloud_compute, CloudCompute)
+
+    cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = False
+
+    with pytest.raises(Exception, match="A Cloud Compute can be assigned only to a single Work"):
+        FlowCC()
@@ -325,6 +325,17 @@ def test_lightning_flow_and_work():
                 "_paths": {},
                 "_restarting": False,
                 "_internal_ip": "",
+                "_cloud_compute": {
+                    "type": "__cloud_compute__",
+                    "name": "default",
+                    "disk_size": 0,
+                    "clusters": None,
+                    "preemptible": False,
+                    "wait_timeout": None,
+                    "idle_timeout": None,
+                    "shm_size": 0,
+                    "_internal_id": "default",
+                },
             },
             "calls": {CacheCallsKeys.LATEST_CALL_HASH: None},
             "changes": {},
@@ -339,6 +350,17 @@ def test_lightning_flow_and_work():
                 "_paths": {},
                 "_restarting": False,
                 "_internal_ip": "",
+                "_cloud_compute": {
+                    "type": "__cloud_compute__",
+                    "name": "default",
+                    "disk_size": 0,
+                    "clusters": None,
+                    "preemptible": False,
+                    "wait_timeout": None,
+                    "idle_timeout": None,
+                    "shm_size": 0,
+                    "_internal_id": "default",
+                },
             },
             "calls": {CacheCallsKeys.LATEST_CALL_HASH: None},
             "changes": {},
@@ -369,6 +391,17 @@ def test_lightning_flow_and_work():
                 "_paths": {},
                 "_restarting": False,
                 "_internal_ip": "",
+                "_cloud_compute": {
+                    "type": "__cloud_compute__",
+                    "name": "default",
+                    "disk_size": 0,
+                    "clusters": None,
+                    "preemptible": False,
+                    "wait_timeout": None,
+                    "idle_timeout": None,
+                    "shm_size": 0,
+                    "_internal_id": "default",
+                },
             },
             "calls": {CacheCallsKeys.LATEST_CALL_HASH: None},
             "changes": {},
@@ -383,6 +416,17 @@ def test_lightning_flow_and_work():
                 "_paths": {},
                 "_restarting": False,
                 "_internal_ip": "",
+                "_cloud_compute": {
+                    "type": "__cloud_compute__",
+                    "name": "default",
+                    "disk_size": 0,
+                    "clusters": None,
+                    "preemptible": False,
+                    "wait_timeout": None,
+                    "idle_timeout": None,
+                    "shm_size": 0,
+                    "_internal_id": "default",
+                },
             },
             "calls": {
                 CacheCallsKeys.LATEST_CALL_HASH: None,
@@ -45,6 +45,17 @@ def test_dict():
             "_paths": {},
             "_restarting": False,
             "_internal_ip": "",
+            "_cloud_compute": {
+                "type": "__cloud_compute__",
+                "name": "default",
+                "disk_size": 0,
+                "clusters": None,
+                "preemptible": False,
+                "wait_timeout": None,
+                "idle_timeout": None,
+                "shm_size": 0,
+                "_internal_id": "default",
+            },
         }
         for k in ("a", "b", "c", "d")
     )
@@ -68,6 +79,17 @@ def test_dict():
             "_paths": {},
             "_restarting": False,
             "_internal_ip": "",
+            "_cloud_compute": {
+                "type": "__cloud_compute__",
+                "name": "default",
+                "disk_size": 0,
+                "clusters": None,
+                "preemptible": False,
+                "wait_timeout": None,
+                "idle_timeout": None,
+                "shm_size": 0,
+                "_internal_id": "default",
+            },
         }
         for k in ("a", "b", "c", "d")
     )
@@ -91,6 +113,17 @@ def test_dict():
             "_paths": {},
             "_restarting": False,
             "_internal_ip": "",
+            "_cloud_compute": {
+                "type": "__cloud_compute__",
+                "name": "default",
+                "disk_size": 0,
+                "clusters": None,
+                "preemptible": False,
+                "wait_timeout": None,
+                "idle_timeout": None,
+                "shm_size": 0,
+                "_internal_id": "default",
+            },
         }
         for k in ("a", "b", "c", "d")
     )
@@ -166,6 +199,17 @@ def test_list():
             "_paths": {},
             "_restarting": False,
             "_internal_ip": "",
+            "_cloud_compute": {
+                "type": "__cloud_compute__",
+                "name": "default",
+                "disk_size": 0,
+                "clusters": None,
+                "preemptible": False,
+                "wait_timeout": None,
+                "idle_timeout": None,
+                "shm_size": 0,
+                "_internal_id": "default",
+            },
         }
         for i in range(4)
     )
@@ -189,6 +233,17 @@ def test_list():
             "_paths": {},
             "_restarting": False,
             "_internal_ip": "",
+            "_cloud_compute": {
+                "type": "__cloud_compute__",
+                "name": "default",
+                "disk_size": 0,
+                "clusters": None,
+                "preemptible": False,
+                "wait_timeout": None,
+                "idle_timeout": None,
+                "shm_size": 0,
+                "_internal_id": "default",
+            },
         }
         for i in range(4)
     )
@@ -207,6 +262,17 @@ def test_list():
             "_paths": {},
             "_restarting": False,
             "_internal_ip": "",
+            "_cloud_compute": {
+                "type": "__cloud_compute__",
+                "name": "default",
+                "disk_size": 0,
+                "clusters": None,
+                "preemptible": False,
+                "wait_timeout": None,
+                "idle_timeout": None,
+                "shm_size": 0,
+                "_internal_id": "default",
+            },
         }
         for i in range(4)
     )
@@ -3,7 +3,6 @@ from unittest import mock
 
 import pytest
 
-from lightning.__version__ import version
 from lightning_app.testing.helpers import RunIf
 from lightning_app.utilities.packaging import lightning_utils
 from lightning_app.utilities.packaging.lightning_utils import (
@@ -16,8 +15,10 @@ def test_prepare_lightning_wheels_and_requirement(tmpdir):
     """This test ensures the lightning source gets packaged inside the lightning repo."""
 
     cleanup_handle = _prepare_lightning_wheels_and_requirements(tmpdir)
+    from lightning.__version__ import version
+
     tar_name = f"lightning-{version}.tar.gz"
-    assert sorted(os.listdir(tmpdir)) == [tar_name]
+    assert sorted(os.listdir(tmpdir))[0] == tar_name
     cleanup_handle()
     assert os.listdir(tmpdir) == []
 
@@ -49,7 +49,7 @@ def test_extract_metadata_from_component():
             "docstring": "WorkA.",
             "local_build_config": {"__build_config__": ANY},
             "cloud_build_config": {"__build_config__": ANY},
-            "cloud_compute": {"__cloud_compute__": ANY},
+            "cloud_compute": ANY,
         },
         {
             "affiliation": ["root", "flow_a_2"],
@@ -64,7 +64,7 @@ def test_extract_metadata_from_component():
             "docstring": "WorkA.",
             "local_build_config": {"__build_config__": ANY},
             "cloud_build_config": {"__build_config__": ANY},
-            "cloud_compute": {"__cloud_compute__": ANY},
+            "cloud_compute": ANY,
         },
         {
             "affiliation": ["root", "flow_b"],
@@ -79,6 +79,6 @@ def test_extract_metadata_from_component():
             "docstring": "WorkB.",
             "local_build_config": {"__build_config__": ANY},
             "cloud_build_config": {"__build_config__": ANY},
-            "cloud_compute": {"__cloud_compute__": ANY},
+            "cloud_compute": ANY,
         },
     ]