diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index e00ae73f41..758884f4dc 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Expose `RunWorkExecutor` to the work and provides default ones for the `MultiNode` Component ([#15561](https://github.com/Lightning-AI/lightning/pull/15561)) +- Added a `start_with_flow` flag to the `LightningWork` which can be disabled to prevent the work from starting at the same time as the flow ([#15591](https://github.com/Lightning-AI/lightning/pull/15591)) ### Changed diff --git a/src/lightning_app/core/work.py b/src/lightning_app/core/work.py index a68923a7ca..84a856b81f 100644 --- a/src/lightning_app/core/work.py +++ b/src/lightning_app/core/work.py @@ -58,6 +58,7 @@ class LightningWork: cloud_build_config: Optional[BuildConfig] = None, cloud_compute: Optional[CloudCompute] = None, run_once: Optional[bool] = None, # TODO: Remove run_once + start_with_flow: bool = True, ): """LightningWork, or Work in short, is a building block for long-running jobs. @@ -80,6 +81,8 @@ class LightningWork: local_build_config: The local BuildConfig isn't used until Lightning supports DockerRuntime. cloud_build_config: The cloud BuildConfig enables user to easily configure machine before running this work. run_once: Deprecated in favor of cache_calls. This will be removed soon. + start_with_flow: Whether the work should be started at the same time as the root flow. Only applies to works + defined in ``__init__``. **Learn More About Lightning Work Inner Workings** @@ -141,6 +144,7 @@ class LightningWork: self._request_queue: Optional[BaseQueue] = None self._response_queue: Optional[BaseQueue] = None self._restarting = False + self._start_with_flow = start_with_flow self._local_build_config = local_build_config or BuildConfig() self._cloud_build_config = cloud_build_config or BuildConfig() self._cloud_compute = cloud_compute or CloudCompute() diff --git a/src/lightning_app/runners/cloud.py b/src/lightning_app/runners/cloud.py index 919f7548bc..fe22c22bd4 100644 --- a/src/lightning_app/runners/cloud.py +++ b/src/lightning_app/runners/cloud.py @@ -136,9 +136,12 @@ class CloudRuntime(Runtime): if not ENABLE_PUSHING_STATE_ENDPOINT: v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0")) - work_reqs: List[V1Work] = [] + works: List[V1Work] = [] for flow in self.app.flows: for work in flow.works(recurse=False): + if not work._start_with_flow: + continue + work_requirements = "\n".join(work.cloud_build_config.requirements) build_spec = V1BuildSpec( commands=work.cloud_build_config.build_commands(), @@ -151,6 +154,7 @@ class CloudRuntime(Runtime): name=work.cloud_compute.name, count=1, disk_size=work.cloud_compute.disk_size, + preemptible=work.cloud_compute.preemptible, shm_size=work.cloud_compute.shm_size, ) @@ -198,13 +202,13 @@ class CloudRuntime(Runtime): ) random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) - spec = V1LightningworkSpec( + work_spec = V1LightningworkSpec( build_spec=build_spec, drives=drive_specs, user_requested_compute_config=user_compute_config, network_config=[V1NetworkConfig(name=random_name, port=work.port)], ) - work_reqs.append(V1Work(name=work.name, spec=spec)) + works.append(V1Work(name=work.name, spec=work_spec)) # We need to collect a spec for each flow that contains a frontend so that the backend knows # for which flows it needs to start servers by invoking the cli (see the serve_frontend() method below) @@ -333,9 +337,6 @@ class CloudRuntime(Runtime): if app_config.cluster_id is not None: self._ensure_cluster_project_binding(project.project_id, app_config.cluster_id) - for work_req in work_reqs: - work_req.spec.cluster_id = app_config.cluster_id - release_body = Body8( app_entrypoint_file=app_spec.app_entrypoint_file, enable_app_server=app_spec.enable_app_server, @@ -343,13 +344,13 @@ class CloudRuntime(Runtime): image_spec=app_spec.image_spec, cluster_id=app_config.cluster_id, network_config=network_configs, - works=[V1Work(name=work_req.name, spec=work_req.spec) for work_req in work_reqs], + works=works, local_source=True, dependency_cache_key=app_spec.dependency_cache_key, user_requested_flow_compute_config=app_spec.user_requested_flow_compute_config, ) - # create / upload the new app release / instace + # create / upload the new app release lightning_app_release = self.backend.client.lightningapp_v2_service_create_lightningapp_release( project_id=project.project_id, app_id=lit_app.id, body=release_body ) diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py index 50be1ea32c..314735794e 100644 --- a/tests/tests_app/runners/test_cloud.py +++ b/tests/tests_app/runners/test_cloud.py @@ -368,8 +368,11 @@ class TestAppCreationClient: assert body.dependency_cache_key is None @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock()) - @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]]) - def test_call_with_work_app(self, lightningapps, monkeypatch, tmpdir): + @pytest.mark.parametrize( + "lightningapps,start_with_flow", + [([], False), ([MagicMock()], False), ([MagicMock()], True)], + ) + def test_call_with_work_app(self, lightningapps, start_with_flow, monkeypatch, tmpdir): source_code_root_dir = Path(tmpdir / "src").absolute() source_code_root_dir.mkdir() Path(source_code_root_dir / ".lightning").write_text("cluster_id: test\nname: myapp") @@ -399,7 +402,7 @@ class TestAppCreationClient: app = mock.MagicMock() flow = mock.MagicMock() - work = MyWork() + work = MyWork(start_with_flow=start_with_flow) monkeypatch.setattr(work, "_name", "test-work") monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"]) monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"]) @@ -431,7 +434,10 @@ class TestAppCreationClient: package_manager=V1PackageManager.PIP, path="requirements.txt" ) ), - works=[ + ) + + if start_with_flow: + expected_body.works = [ V1Work( name="test-work", spec=V1LightningworkSpec( @@ -444,14 +450,19 @@ class TestAppCreationClient: ), drives=[], user_requested_compute_config=V1UserRequestedComputeConfig( - name="default", count=1, disk_size=0, shm_size=0 + name="default", + count=1, + disk_size=0, + shm_size=0, + preemptible=False, ), network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)], - cluster_id="test", ), ) - ], - ) + ] + else: + expected_body.works = [] + mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( project_id="test-project-id", app_id=mock.ANY, body=expected_body ) @@ -637,10 +648,13 @@ class TestAppCreationClient: ), ], user_requested_compute_config=V1UserRequestedComputeConfig( - name="default", count=1, disk_size=0, shm_size=0 + name="default", + count=1, + disk_size=0, + shm_size=0, + preemptible=False, ), network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)], - cluster_id="test", ), ) ], @@ -788,10 +802,13 @@ class TestAppCreationClient: ), drives=[lit_drive_2_spec, lit_drive_1_spec], user_requested_compute_config=V1UserRequestedComputeConfig( - name="default", count=1, disk_size=0, shm_size=0 + name="default", + count=1, + disk_size=0, + shm_size=0, + preemptible=False, ), network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)], - cluster_id="test", ), ) ], @@ -824,10 +841,13 @@ class TestAppCreationClient: ), drives=[lit_drive_1_spec, lit_drive_2_spec], user_requested_compute_config=V1UserRequestedComputeConfig( - name="default", count=1, disk_size=0, shm_size=0 + name="default", + count=1, + disk_size=0, + shm_size=0, + preemptible=False, ), network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)], - cluster_id="test", ), ) ], @@ -989,10 +1009,13 @@ class TestAppCreationClient: ), ], user_requested_compute_config=V1UserRequestedComputeConfig( - name="default", count=1, disk_size=0, shm_size=0 + name="default", + count=1, + disk_size=0, + shm_size=0, + preemptible=False, ), network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)], - cluster_id="test", ), ) ],