diff --git a/src/lightning/app/core/constants.py b/src/lightning/app/core/constants.py index 557ffd9aad..5a303b7d9c 100644 --- a/src/lightning/app/core/constants.py +++ b/src/lightning/app/core/constants.py @@ -62,7 +62,6 @@ DISABLE_DEPENDENCY_CACHE = bool(int(os.getenv("DISABLE_DEPENDENCY_CACHE", "0"))) # Project under which the resources need to run in cloud. If this env is not set, # cloud runner will try to get the default project from the cloud LIGHTNING_CLOUD_PROJECT_ID = os.getenv("LIGHTNING_CLOUD_PROJECT_ID") -LIGHTNING_CLOUD_ORGANIZATION_ID = os.getenv("LIGHTNING_CLOUD_ORGANIZATION_ID") LIGHTNING_CLOUD_PRINT_SPECS = os.getenv("LIGHTNING_CLOUD_PRINT_SPECS") LIGHTNING_DIR = os.getenv("LIGHTNING_DIR", str(Path.home() / ".lightning")) LIGHTNING_CREDENTIAL_PATH = os.getenv("LIGHTNING_CREDENTIAL_PATH", str(Path(LIGHTNING_DIR) / "credentials.json")) diff --git a/src/lightning/app/utilities/cloud.py b/src/lightning/app/utilities/cloud.py index 31beabab88..95c76c1be9 100644 --- a/src/lightning/app/utilities/cloud.py +++ b/src/lightning/app/utilities/cloud.py @@ -18,36 +18,35 @@ from typing import Optional from lightning_cloud.openapi import V1Membership import lightning.app -from lightning.app.core import constants +from lightning.app.core.constants import LIGHTNING_CLOUD_PROJECT_ID from lightning.app.utilities.enum import AppStage from lightning.app.utilities.network import LightningClient -def _get_project( - client: LightningClient, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - verbose: bool = True, -) -> V1Membership: +def _get_project(client: LightningClient, project_id: Optional[str] = None, verbose: bool = True) -> V1Membership: """Get a project membership for the user from the backend.""" if project_id is None: - project_id = constants.LIGHTNING_CLOUD_PROJECT_ID - if organization_id is None: - organization_id = constants.LIGHTNING_CLOUD_ORGANIZATION_ID + project_id = LIGHTNING_CLOUD_PROJECT_ID - projects = client.projects_service_list_memberships( - **({"organization_id": organization_id} if organization_id is not None else {}) - ) if project_id is not None: - for membership in projects.memberships: - if membership.project_id == project_id: - break - else: + project = client.projects_service_get_project(project_id) + if not project: raise ValueError( "Environment variable `LIGHTNING_CLOUD_PROJECT_ID` is set but could not find an associated project." ) - return membership + return V1Membership( + name=project.name, + display_name=project.display_name, + description=project.description, + created_at=project.created_at, + project_id=project.id, + owner_id=project.owner_id, + owner_type=project.owner_type, + quotas=project.quotas, + updated_at=project.updated_at, + ) + projects = client.projects_service_list_memberships() if len(projects.memberships) == 0: raise ValueError("No valid projects found. Please reach out to lightning.ai team to create a project") if len(projects.memberships) > 1 and verbose: diff --git a/tests/tests_app/utilities/test_cloud.py b/tests/tests_app/utilities/test_cloud.py index c0346a8305..203812d600 100644 --- a/tests/tests_app/utilities/test_cloud.py +++ b/tests/tests_app/utilities/test_cloud.py @@ -1,29 +1,19 @@ import os from unittest import mock -from lightning_cloud.openapi.models import V1ListMembershipsResponse, V1Membership +from lightning_cloud.openapi.models import V1Project from lightning.app.utilities.cloud import _get_project, is_running_in_cloud -@mock.patch("lightning.app.core.constants.LIGHTNING_CLOUD_ORGANIZATION_ID", "organization_id") -def test_get_project_picks_up_organization_id(): - """Uses organization_id from `LIGHTNING_CLOUD_ORGANIZATION_ID` config var if none passed.""" +def test_get_project_queries_by_project_id_directly_if_it_is_passed(): lightning_client = mock.MagicMock() - lightning_client.projects_service_list_memberships = mock.MagicMock( - return_value=V1ListMembershipsResponse(memberships=[V1Membership(project_id="project_id")]), + lightning_client.projects_service_get_project = mock.MagicMock( + return_value=V1Project(id="project_id"), ) - _get_project(lightning_client) - lightning_client.projects_service_list_memberships.assert_called_once_with(organization_id="organization_id") - - -def test_get_project_doesnt_pass_organization_id_if_its_not_set(): - lightning_client = mock.MagicMock() - lightning_client.projects_service_list_memberships = mock.MagicMock( - return_value=V1ListMembershipsResponse(memberships=[V1Membership(project_id="project_id")]), - ) - _get_project(lightning_client) - lightning_client.projects_service_list_memberships.assert_called_once_with() + project = _get_project(lightning_client, project_id="project_id") + assert project.project_id == "project_id" + lightning_client.projects_service_get_project.assert_called_once_with("project_id") def test_is_running_cloud():