diff --git a/src/lightning_app/core/constants.py b/src/lightning_app/core/constants.py index 7b799b1347..1c5dc862ea 100644 --- a/src/lightning_app/core/constants.py +++ b/src/lightning_app/core/constants.py @@ -22,7 +22,7 @@ APP_SERVER_HOST = os.getenv("LIGHTNING_APP_STATE_URL", "http://127.0.0.1") APP_SERVER_PORT = 7501 APP_STATE_MAX_SIZE_BYTES = 1024 * 1024 # 1 MB -CLOUD_QUEUE_TYPE = os.getenv("LIGHTNING_CLOUD_QUEUE_TYPE", "redis") +CLOUD_QUEUE_TYPE = os.getenv("LIGHTNING_CLOUD_QUEUE_TYPE", None) WARNING_QUEUE_SIZE = 1000 # different flag because queue debug can be very noisy, and almost always not useful unless debugging the queue itself. QUEUE_DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_QUEUE_DEBUG_ENABLED", "0"))) diff --git a/src/lightning_app/runners/cloud.py b/src/lightning_app/runners/cloud.py index 2c66734d98..fe780c25f6 100644 --- a/src/lightning_app/runners/cloud.py +++ b/src/lightning_app/runners/cloud.py @@ -289,7 +289,13 @@ class CloudRuntime(Runtime): find_instances_resp = self.backend.client.lightningapp_instance_service_list_lightningapp_instances( project_id=project.project_id, app_id=lit_app.id ) - queue_server_type = V1QueueServerType.REDIS if CLOUD_QUEUE_TYPE == "redis" else V1QueueServerType.HTTP + + queue_server_type = V1QueueServerType.UNSPECIFIED + if CLOUD_QUEUE_TYPE == "http": + queue_server_type = V1QueueServerType.HTTP + elif CLOUD_QUEUE_TYPE == "redis": + queue_server_type = V1QueueServerType.REDIS + if find_instances_resp.lightningapps: existing_instance = find_instances_resp.lightningapps[0] if existing_instance.status.phase != V1LightningappInstanceState.STOPPED: diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index 53c39e1ca0..66873a70ac 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -330,6 +330,7 @@ async def test_health_endpoint_success(): @pytest.mark.anyio async def test_health_endpoint_failure(monkeypatch): monkeypatch.setenv("LIGHTNING_APP_STATE_URL", "http://someurl") # adding this to make is_running_in_cloud pass + monkeypatch.setattr(api, "CLOUD_QUEUE_TYPE", "redis") async with AsyncClient(app=fastapi_service, base_url="http://test") as client: # will respond 503 if redis is not running response = await client.get("/healthz") diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py index 94064a8e51..b99efd9483 100644 --- a/tests/tests_app/runners/test_cloud.py +++ b/tests/tests_app/runners/test_cloud.py @@ -8,6 +8,7 @@ import pytest from lightning_cloud.openapi import ( Body8, Gridv1ImageSpec, + IdGetBody, V1BuildSpec, V1DependencyFileInfo, V1Drive, @@ -25,6 +26,7 @@ from lightning_cloud.openapi import ( V1PackageManager, V1ProjectClusterBinding, V1PythonDependencyInfo, + V1QueueServerType, V1SourceType, V1UserRequestedComputeConfig, V1Work, @@ -304,6 +306,62 @@ class TestAppCreationClient: project_id="test-project-id", app_id=mock.ANY, id=mock.ANY, body=mock.ANY ) + @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock()) + @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]]) + def test_call_with_queue_server_type_specified(self, lightningapps, monkeypatch, tmpdir): + mock_client = mock.MagicMock() + mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse( + memberships=[V1Membership(name="test-project", project_id="test-project-id")] + ) + mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = ( + V1ListLightningappInstancesResponse(lightningapps=[]) + ) + cloud_backend = mock.MagicMock() + cloud_backend.client = mock_client + monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend)) + monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock()) + monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock()) + app = mock.MagicMock() + app.flows = [] + app.frontend = {} + cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file="entrypoint.py") + cloud_runtime._check_uploaded_folder = mock.MagicMock() + + # without requirements file + # setting is_file to False so requirements.txt existence check will return False + monkeypatch.setattr(Path, "is_file", lambda *args, **kwargs: False) + monkeypatch.setattr(cloud, "Path", Path) + cloud_runtime.dispatch() + + # calling with no env variable set + body = IdGetBody( + cluster_id=None, + desired_state=V1LightningappInstanceState.STOPPED, + env=[], + name=mock.ANY, + queue_server_type=V1QueueServerType.UNSPECIFIED, + ) + client = cloud_runtime.backend.client + client.lightningapp_v2_service_create_lightningapp_release_instance.assert_called_once_with( + project_id="test-project-id", app_id=mock.ANY, id=mock.ANY, body=body + ) + + # calling with env variable set to http + monkeypatch.setattr(cloud, "CLOUD_QUEUE_TYPE", "http") + cloud_runtime.backend.client.reset_mock() + cloud_runtime.dispatch() + body = IdGetBody( + cluster_id=None, + desired_state=V1LightningappInstanceState.STOPPED, + env=[], + name=mock.ANY, + queue_server_type=V1QueueServerType.HTTP, + ) + client = cloud_runtime.backend.client + client.lightningapp_v2_service_create_lightningapp_release_instance.assert_called_once_with( + project_id="test-project-id", app_id=mock.ANY, id=mock.ANY, body=body + ) + @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock()) @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]]) def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch, tmpdir): diff --git a/tests/tests_app_examples/custom_work_dependencies/app.py b/tests/tests_app_examples/custom_work_dependencies/app.py index 06e5f40d52..53a8348a83 100644 --- a/tests/tests_app_examples/custom_work_dependencies/app.py +++ b/tests/tests_app_examples/custom_work_dependencies/app.py @@ -24,7 +24,7 @@ class WorkWithCustomBaseImage(LightningWork): def __init__(self, cloud_compute: CloudCompute = CloudCompute(), **kwargs): # this image has been created from ghcr.io/gridai/base-images:v1.8-cpu # by just adding an empty file at /content/.e2e_test - custom_image = "ghcr.io/gridai/image-for-testing-custom-images-in-e2e:v0.0.1" + custom_image = "ghcr.io/gridai/image-for-testing-custom-images-in-e2e:v1.12" build_config = BuildConfig(image=custom_image) super().__init__(parallel=True, **kwargs, cloud_compute=cloud_compute, cloud_build_config=build_config)