Fix default CloudCompute for flows (#15371)

* Fix default CloudCompute for flows
* Unit test added
This commit is contained in:
Dmitry Frolov 2022-10-29 03:36:59 -04:00 committed by GitHub
parent 3fb98ad074
commit 11196b1707
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 18 deletions

View File

@ -38,6 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where the upload files endpoint would raise an error when running locally ([#14924](https://github.com/Lightning-AI/lightning/pull/14924))
- Fixed BYOC cluster region selector -> hiding it from help since only us-east-1 has been tested and is recommended ([#15277]https://github.com/Lightning-AI/lightning/pull/15277)
- Fixed a bug when launching an app on multiple clusters ([#15226](https://github.com/Lightning-AI/lightning/pull/15226))
- Fixed a bug with a default CloudCompute for Lightning flows ([#15371](https://github.com/Lightning-AI/lightning/pull/15371))
## [0.6.2] - 2022-09-21

View File

@ -103,7 +103,7 @@ class LightningApp:
_validate_root_flow(root)
self._root = root
self.flow_cloud_compute = flow_cloud_compute or lightning_app.CloudCompute()
self.flow_cloud_compute = flow_cloud_compute or lightning_app.CloudCompute(name="flow-lite")
# queues definition.
self.delta_queue: Optional[BaseQueue] = None

View File

@ -68,6 +68,28 @@ class WorkWithTwoDrives(LightningWork):
pass
def get_cloud_runtime_request_body(**kwargs) -> "Body8":
default_request_body = dict(
app_entrypoint_file=mock.ANY,
enable_app_server=True,
flow_servers=[],
image_spec=None,
works=[],
local_source=True,
dependency_cache_key=mock.ANY,
user_requested_flow_compute_config=V1UserRequestedFlowComputeConfig(
name="flow-lite",
preemptible=False,
shm_size=0,
),
)
if kwargs.get("user_requested_flow_compute_config") is not None:
default_request_body["user_requested_flow_compute_config"] = kwargs["user_requested_flow_compute_config"]
return Body8(**default_request_body)
class TestAppCreationClient:
"""Testing the calls made using GridRestClient to create the app."""
@ -138,8 +160,9 @@ class TestAppCreationClient:
body=V1ProjectClusterBinding(cluster_id=new_cluster, project_id="default-project-id"),
)
@pytest.mark.parametrize("flow_cloud_compute", [None, CloudCompute(name="t2.medium")])
@mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock())
def test_run_with_custom_flow_compute_config(self, monkeypatch):
def test_run_with_default_flow_compute_config(self, monkeypatch, flow_cloud_compute):
mock_client = mock.MagicMock()
mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
memberships=[V1Membership(name="test-project", project_id="test-project-id")]
@ -155,30 +178,30 @@ class TestAppCreationClient:
cloud_backend.client = mock_client
monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
app = mock.MagicMock()
app.flows = []
app.frontend = {}
app.flow_cloud_compute = CloudCompute(name="t2.medium")
dummy_flow = mock.MagicMock()
monkeypatch.setattr(dummy_flow, "run", lambda *args, **kwargs: None)
if flow_cloud_compute is None:
app = LightningApp(dummy_flow)
else:
app = LightningApp(dummy_flow, flow_cloud_compute=flow_cloud_compute)
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file="entrypoint.py")
cloud_runtime._check_uploaded_folder = mock.MagicMock()
monkeypatch.setattr(Path, "is_file", lambda *args, **kwargs: False)
monkeypatch.setattr(cloud, "Path", Path)
cloud_runtime.dispatch()
body = Body8(
app_entrypoint_file=mock.ANY,
enable_app_server=True,
flow_servers=[],
image_spec=None,
works=[],
local_source=True,
dependency_cache_key=mock.ANY,
user_requested_flow_compute_config=V1UserRequestedFlowComputeConfig(
name="t2.medium",
user_requested_flow_compute_config = None
if flow_cloud_compute is not None:
user_requested_flow_compute_config = V1UserRequestedFlowComputeConfig(
name=flow_cloud_compute.name,
preemptible=False,
shm_size=0,
),
)
)
body = get_cloud_runtime_request_body(user_requested_flow_compute_config=user_requested_flow_compute_config)
cloud_runtime.backend.client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
project_id="test-project-id", app_id=mock.ANY, body=body
)