[App] Add `start_with_flow` flag to works (#15591)

* Initial commit

* Update cloud runner

* Add `start_with_flow` flag

* Update CHANGELOG.md

* Update src/lightning_app/core/work.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update cloud runner

* Revert, not needed

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Ethan Harris 2022-11-09 13:54:22 +00:00 committed by GitHub
parent fc78d8d6e5
commit 733695d037
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 24 deletions

View File

@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Expose `RunWorkExecutor` to the work and provides default ones for the `MultiNode` Component ([#15561](https://github.com/Lightning-AI/lightning/pull/15561))
- Added a `start_with_flow` flag to the `LightningWork` which can be disabled to prevent the work from starting at the same time as the flow ([#15591](https://github.com/Lightning-AI/lightning/pull/15591))
### Changed

View File

@ -58,6 +58,7 @@ class LightningWork:
cloud_build_config: Optional[BuildConfig] = None,
cloud_compute: Optional[CloudCompute] = None,
run_once: Optional[bool] = None, # TODO: Remove run_once
start_with_flow: bool = True,
):
"""LightningWork, or Work in short, is a building block for long-running jobs.
@ -80,6 +81,8 @@ class LightningWork:
local_build_config: The local BuildConfig isn't used until Lightning supports DockerRuntime.
cloud_build_config: The cloud BuildConfig enables user to easily configure machine before running this work.
run_once: Deprecated in favor of cache_calls. This will be removed soon.
start_with_flow: Whether the work should be started at the same time as the root flow. Only applies to works
defined in ``__init__``.
**Learn More About Lightning Work Inner Workings**
@ -141,6 +144,7 @@ class LightningWork:
self._request_queue: Optional[BaseQueue] = None
self._response_queue: Optional[BaseQueue] = None
self._restarting = False
self._start_with_flow = start_with_flow
self._local_build_config = local_build_config or BuildConfig()
self._cloud_build_config = cloud_build_config or BuildConfig()
self._cloud_compute = cloud_compute or CloudCompute()

View File

@ -136,9 +136,12 @@ class CloudRuntime(Runtime):
if not ENABLE_PUSHING_STATE_ENDPOINT:
v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))
work_reqs: List[V1Work] = []
works: List[V1Work] = []
for flow in self.app.flows:
for work in flow.works(recurse=False):
if not work._start_with_flow:
continue
work_requirements = "\n".join(work.cloud_build_config.requirements)
build_spec = V1BuildSpec(
commands=work.cloud_build_config.build_commands(),
@ -151,6 +154,7 @@ class CloudRuntime(Runtime):
name=work.cloud_compute.name,
count=1,
disk_size=work.cloud_compute.disk_size,
preemptible=work.cloud_compute.preemptible,
shm_size=work.cloud_compute.shm_size,
)
@ -198,13 +202,13 @@ class CloudRuntime(Runtime):
)
random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
spec = V1LightningworkSpec(
work_spec = V1LightningworkSpec(
build_spec=build_spec,
drives=drive_specs,
user_requested_compute_config=user_compute_config,
network_config=[V1NetworkConfig(name=random_name, port=work.port)],
)
work_reqs.append(V1Work(name=work.name, spec=spec))
works.append(V1Work(name=work.name, spec=work_spec))
# We need to collect a spec for each flow that contains a frontend so that the backend knows
# for which flows it needs to start servers by invoking the cli (see the serve_frontend() method below)
@ -333,9 +337,6 @@ class CloudRuntime(Runtime):
if app_config.cluster_id is not None:
self._ensure_cluster_project_binding(project.project_id, app_config.cluster_id)
for work_req in work_reqs:
work_req.spec.cluster_id = app_config.cluster_id
release_body = Body8(
app_entrypoint_file=app_spec.app_entrypoint_file,
enable_app_server=app_spec.enable_app_server,
@ -343,13 +344,13 @@ class CloudRuntime(Runtime):
image_spec=app_spec.image_spec,
cluster_id=app_config.cluster_id,
network_config=network_configs,
works=[V1Work(name=work_req.name, spec=work_req.spec) for work_req in work_reqs],
works=works,
local_source=True,
dependency_cache_key=app_spec.dependency_cache_key,
user_requested_flow_compute_config=app_spec.user_requested_flow_compute_config,
)
# create / upload the new app release / instace
# create / upload the new app release
lightning_app_release = self.backend.client.lightningapp_v2_service_create_lightningapp_release(
project_id=project.project_id, app_id=lit_app.id, body=release_body
)

View File

@ -368,8 +368,11 @@ class TestAppCreationClient:
assert body.dependency_cache_key is None
@mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock())
@pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
def test_call_with_work_app(self, lightningapps, monkeypatch, tmpdir):
@pytest.mark.parametrize(
"lightningapps,start_with_flow",
[([], False), ([MagicMock()], False), ([MagicMock()], True)],
)
def test_call_with_work_app(self, lightningapps, start_with_flow, monkeypatch, tmpdir):
source_code_root_dir = Path(tmpdir / "src").absolute()
source_code_root_dir.mkdir()
Path(source_code_root_dir / ".lightning").write_text("cluster_id: test\nname: myapp")
@ -399,7 +402,7 @@ class TestAppCreationClient:
app = mock.MagicMock()
flow = mock.MagicMock()
work = MyWork()
work = MyWork(start_with_flow=start_with_flow)
monkeypatch.setattr(work, "_name", "test-work")
monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
@ -431,7 +434,10 @@ class TestAppCreationClient:
package_manager=V1PackageManager.PIP, path="requirements.txt"
)
),
works=[
)
if start_with_flow:
expected_body.works = [
V1Work(
name="test-work",
spec=V1LightningworkSpec(
@ -444,14 +450,19 @@ class TestAppCreationClient:
),
drives=[],
user_requested_compute_config=V1UserRequestedComputeConfig(
name="default", count=1, disk_size=0, shm_size=0
name="default",
count=1,
disk_size=0,
shm_size=0,
preemptible=False,
),
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
cluster_id="test",
),
)
],
)
]
else:
expected_body.works = []
mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
project_id="test-project-id", app_id=mock.ANY, body=expected_body
)
@ -637,10 +648,13 @@ class TestAppCreationClient:
),
],
user_requested_compute_config=V1UserRequestedComputeConfig(
name="default", count=1, disk_size=0, shm_size=0
name="default",
count=1,
disk_size=0,
shm_size=0,
preemptible=False,
),
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
cluster_id="test",
),
)
],
@ -788,10 +802,13 @@ class TestAppCreationClient:
),
drives=[lit_drive_2_spec, lit_drive_1_spec],
user_requested_compute_config=V1UserRequestedComputeConfig(
name="default", count=1, disk_size=0, shm_size=0
name="default",
count=1,
disk_size=0,
shm_size=0,
preemptible=False,
),
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
cluster_id="test",
),
)
],
@ -824,10 +841,13 @@ class TestAppCreationClient:
),
drives=[lit_drive_1_spec, lit_drive_2_spec],
user_requested_compute_config=V1UserRequestedComputeConfig(
name="default", count=1, disk_size=0, shm_size=0
name="default",
count=1,
disk_size=0,
shm_size=0,
preemptible=False,
),
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
cluster_id="test",
),
)
],
@ -989,10 +1009,13 @@ class TestAppCreationClient:
),
],
user_requested_compute_config=V1UserRequestedComputeConfig(
name="default", count=1, disk_size=0, shm_size=0
name="default",
count=1,
disk_size=0,
shm_size=0,
preemptible=False,
),
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
cluster_id="test",
),
)
],