[App] Add `start_with_flow` flag to works (#15591)
* Initial commit * Update cloud runner * Add `start_with_flow` flag * Update CHANGELOG.md * Update src/lightning_app/core/work.py Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * Update cloud runner * Revert, not needed Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
fc78d8d6e5
commit
733695d037
|
@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Expose `RunWorkExecutor` to the work and provides default ones for the `MultiNode` Component ([#15561](https://github.com/Lightning-AI/lightning/pull/15561))
|
||||
|
||||
- Added a `start_with_flow` flag to the `LightningWork` which can be disabled to prevent the work from starting at the same time as the flow ([#15591](https://github.com/Lightning-AI/lightning/pull/15591))
|
||||
|
||||
### Changed
|
||||
|
||||
|
|
|
@ -58,6 +58,7 @@ class LightningWork:
|
|||
cloud_build_config: Optional[BuildConfig] = None,
|
||||
cloud_compute: Optional[CloudCompute] = None,
|
||||
run_once: Optional[bool] = None, # TODO: Remove run_once
|
||||
start_with_flow: bool = True,
|
||||
):
|
||||
"""LightningWork, or Work in short, is a building block for long-running jobs.
|
||||
|
||||
|
@ -80,6 +81,8 @@ class LightningWork:
|
|||
local_build_config: The local BuildConfig isn't used until Lightning supports DockerRuntime.
|
||||
cloud_build_config: The cloud BuildConfig enables user to easily configure machine before running this work.
|
||||
run_once: Deprecated in favor of cache_calls. This will be removed soon.
|
||||
start_with_flow: Whether the work should be started at the same time as the root flow. Only applies to works
|
||||
defined in ``__init__``.
|
||||
|
||||
**Learn More About Lightning Work Inner Workings**
|
||||
|
||||
|
@ -141,6 +144,7 @@ class LightningWork:
|
|||
self._request_queue: Optional[BaseQueue] = None
|
||||
self._response_queue: Optional[BaseQueue] = None
|
||||
self._restarting = False
|
||||
self._start_with_flow = start_with_flow
|
||||
self._local_build_config = local_build_config or BuildConfig()
|
||||
self._cloud_build_config = cloud_build_config or BuildConfig()
|
||||
self._cloud_compute = cloud_compute or CloudCompute()
|
||||
|
|
|
@ -136,9 +136,12 @@ class CloudRuntime(Runtime):
|
|||
if not ENABLE_PUSHING_STATE_ENDPOINT:
|
||||
v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))
|
||||
|
||||
work_reqs: List[V1Work] = []
|
||||
works: List[V1Work] = []
|
||||
for flow in self.app.flows:
|
||||
for work in flow.works(recurse=False):
|
||||
if not work._start_with_flow:
|
||||
continue
|
||||
|
||||
work_requirements = "\n".join(work.cloud_build_config.requirements)
|
||||
build_spec = V1BuildSpec(
|
||||
commands=work.cloud_build_config.build_commands(),
|
||||
|
@ -151,6 +154,7 @@ class CloudRuntime(Runtime):
|
|||
name=work.cloud_compute.name,
|
||||
count=1,
|
||||
disk_size=work.cloud_compute.disk_size,
|
||||
preemptible=work.cloud_compute.preemptible,
|
||||
shm_size=work.cloud_compute.shm_size,
|
||||
)
|
||||
|
||||
|
@ -198,13 +202,13 @@ class CloudRuntime(Runtime):
|
|||
)
|
||||
|
||||
random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
|
||||
spec = V1LightningworkSpec(
|
||||
work_spec = V1LightningworkSpec(
|
||||
build_spec=build_spec,
|
||||
drives=drive_specs,
|
||||
user_requested_compute_config=user_compute_config,
|
||||
network_config=[V1NetworkConfig(name=random_name, port=work.port)],
|
||||
)
|
||||
work_reqs.append(V1Work(name=work.name, spec=spec))
|
||||
works.append(V1Work(name=work.name, spec=work_spec))
|
||||
|
||||
# We need to collect a spec for each flow that contains a frontend so that the backend knows
|
||||
# for which flows it needs to start servers by invoking the cli (see the serve_frontend() method below)
|
||||
|
@ -333,9 +337,6 @@ class CloudRuntime(Runtime):
|
|||
if app_config.cluster_id is not None:
|
||||
self._ensure_cluster_project_binding(project.project_id, app_config.cluster_id)
|
||||
|
||||
for work_req in work_reqs:
|
||||
work_req.spec.cluster_id = app_config.cluster_id
|
||||
|
||||
release_body = Body8(
|
||||
app_entrypoint_file=app_spec.app_entrypoint_file,
|
||||
enable_app_server=app_spec.enable_app_server,
|
||||
|
@ -343,13 +344,13 @@ class CloudRuntime(Runtime):
|
|||
image_spec=app_spec.image_spec,
|
||||
cluster_id=app_config.cluster_id,
|
||||
network_config=network_configs,
|
||||
works=[V1Work(name=work_req.name, spec=work_req.spec) for work_req in work_reqs],
|
||||
works=works,
|
||||
local_source=True,
|
||||
dependency_cache_key=app_spec.dependency_cache_key,
|
||||
user_requested_flow_compute_config=app_spec.user_requested_flow_compute_config,
|
||||
)
|
||||
|
||||
# create / upload the new app release / instace
|
||||
# create / upload the new app release
|
||||
lightning_app_release = self.backend.client.lightningapp_v2_service_create_lightningapp_release(
|
||||
project_id=project.project_id, app_id=lit_app.id, body=release_body
|
||||
)
|
||||
|
|
|
@ -368,8 +368,11 @@ class TestAppCreationClient:
|
|||
assert body.dependency_cache_key is None
|
||||
|
||||
@mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock())
|
||||
@pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
|
||||
def test_call_with_work_app(self, lightningapps, monkeypatch, tmpdir):
|
||||
@pytest.mark.parametrize(
|
||||
"lightningapps,start_with_flow",
|
||||
[([], False), ([MagicMock()], False), ([MagicMock()], True)],
|
||||
)
|
||||
def test_call_with_work_app(self, lightningapps, start_with_flow, monkeypatch, tmpdir):
|
||||
source_code_root_dir = Path(tmpdir / "src").absolute()
|
||||
source_code_root_dir.mkdir()
|
||||
Path(source_code_root_dir / ".lightning").write_text("cluster_id: test\nname: myapp")
|
||||
|
@ -399,7 +402,7 @@ class TestAppCreationClient:
|
|||
app = mock.MagicMock()
|
||||
flow = mock.MagicMock()
|
||||
|
||||
work = MyWork()
|
||||
work = MyWork(start_with_flow=start_with_flow)
|
||||
monkeypatch.setattr(work, "_name", "test-work")
|
||||
monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
|
||||
monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
|
||||
|
@ -431,7 +434,10 @@ class TestAppCreationClient:
|
|||
package_manager=V1PackageManager.PIP, path="requirements.txt"
|
||||
)
|
||||
),
|
||||
works=[
|
||||
)
|
||||
|
||||
if start_with_flow:
|
||||
expected_body.works = [
|
||||
V1Work(
|
||||
name="test-work",
|
||||
spec=V1LightningworkSpec(
|
||||
|
@ -444,14 +450,19 @@ class TestAppCreationClient:
|
|||
),
|
||||
drives=[],
|
||||
user_requested_compute_config=V1UserRequestedComputeConfig(
|
||||
name="default", count=1, disk_size=0, shm_size=0
|
||||
name="default",
|
||||
count=1,
|
||||
disk_size=0,
|
||||
shm_size=0,
|
||||
preemptible=False,
|
||||
),
|
||||
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
|
||||
cluster_id="test",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
else:
|
||||
expected_body.works = []
|
||||
|
||||
mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
|
||||
project_id="test-project-id", app_id=mock.ANY, body=expected_body
|
||||
)
|
||||
|
@ -637,10 +648,13 @@ class TestAppCreationClient:
|
|||
),
|
||||
],
|
||||
user_requested_compute_config=V1UserRequestedComputeConfig(
|
||||
name="default", count=1, disk_size=0, shm_size=0
|
||||
name="default",
|
||||
count=1,
|
||||
disk_size=0,
|
||||
shm_size=0,
|
||||
preemptible=False,
|
||||
),
|
||||
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
|
||||
cluster_id="test",
|
||||
),
|
||||
)
|
||||
],
|
||||
|
@ -788,10 +802,13 @@ class TestAppCreationClient:
|
|||
),
|
||||
drives=[lit_drive_2_spec, lit_drive_1_spec],
|
||||
user_requested_compute_config=V1UserRequestedComputeConfig(
|
||||
name="default", count=1, disk_size=0, shm_size=0
|
||||
name="default",
|
||||
count=1,
|
||||
disk_size=0,
|
||||
shm_size=0,
|
||||
preemptible=False,
|
||||
),
|
||||
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
|
||||
cluster_id="test",
|
||||
),
|
||||
)
|
||||
],
|
||||
|
@ -824,10 +841,13 @@ class TestAppCreationClient:
|
|||
),
|
||||
drives=[lit_drive_1_spec, lit_drive_2_spec],
|
||||
user_requested_compute_config=V1UserRequestedComputeConfig(
|
||||
name="default", count=1, disk_size=0, shm_size=0
|
||||
name="default",
|
||||
count=1,
|
||||
disk_size=0,
|
||||
shm_size=0,
|
||||
preemptible=False,
|
||||
),
|
||||
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
|
||||
cluster_id="test",
|
||||
),
|
||||
)
|
||||
],
|
||||
|
@ -989,10 +1009,13 @@ class TestAppCreationClient:
|
|||
),
|
||||
],
|
||||
user_requested_compute_config=V1UserRequestedComputeConfig(
|
||||
name="default", count=1, disk_size=0, shm_size=0
|
||||
name="default",
|
||||
count=1,
|
||||
disk_size=0,
|
||||
shm_size=0,
|
||||
preemptible=False,
|
||||
),
|
||||
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
|
||||
cluster_id="test",
|
||||
),
|
||||
)
|
||||
],
|
||||
|
|
Loading…
Reference in New Issue