fix: get project (#17666)
This commit is contained in:
parent
3a6d0d80c3
commit
61246c3b35
|
@ -62,7 +62,6 @@ DISABLE_DEPENDENCY_CACHE = bool(int(os.getenv("DISABLE_DEPENDENCY_CACHE", "0")))
|
|||
# Project under which the resources need to run in cloud. If this env is not set,
|
||||
# cloud runner will try to get the default project from the cloud
|
||||
LIGHTNING_CLOUD_PROJECT_ID = os.getenv("LIGHTNING_CLOUD_PROJECT_ID")
|
||||
LIGHTNING_CLOUD_ORGANIZATION_ID = os.getenv("LIGHTNING_CLOUD_ORGANIZATION_ID")
|
||||
LIGHTNING_CLOUD_PRINT_SPECS = os.getenv("LIGHTNING_CLOUD_PRINT_SPECS")
|
||||
LIGHTNING_DIR = os.getenv("LIGHTNING_DIR", str(Path.home() / ".lightning"))
|
||||
LIGHTNING_CREDENTIAL_PATH = os.getenv("LIGHTNING_CREDENTIAL_PATH", str(Path(LIGHTNING_DIR) / "credentials.json"))
|
||||
|
|
|
@ -18,36 +18,35 @@ from typing import Optional
|
|||
from lightning_cloud.openapi import V1Membership
|
||||
|
||||
import lightning.app
|
||||
from lightning.app.core import constants
|
||||
from lightning.app.core.constants import LIGHTNING_CLOUD_PROJECT_ID
|
||||
from lightning.app.utilities.enum import AppStage
|
||||
from lightning.app.utilities.network import LightningClient
|
||||
|
||||
|
||||
def _get_project(
|
||||
client: LightningClient,
|
||||
organization_id: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
verbose: bool = True,
|
||||
) -> V1Membership:
|
||||
def _get_project(client: LightningClient, project_id: Optional[str] = None, verbose: bool = True) -> V1Membership:
|
||||
"""Get a project membership for the user from the backend."""
|
||||
if project_id is None:
|
||||
project_id = constants.LIGHTNING_CLOUD_PROJECT_ID
|
||||
if organization_id is None:
|
||||
organization_id = constants.LIGHTNING_CLOUD_ORGANIZATION_ID
|
||||
project_id = LIGHTNING_CLOUD_PROJECT_ID
|
||||
|
||||
projects = client.projects_service_list_memberships(
|
||||
**({"organization_id": organization_id} if organization_id is not None else {})
|
||||
)
|
||||
if project_id is not None:
|
||||
for membership in projects.memberships:
|
||||
if membership.project_id == project_id:
|
||||
break
|
||||
else:
|
||||
project = client.projects_service_get_project(project_id)
|
||||
if not project:
|
||||
raise ValueError(
|
||||
"Environment variable `LIGHTNING_CLOUD_PROJECT_ID` is set but could not find an associated project."
|
||||
)
|
||||
return membership
|
||||
return V1Membership(
|
||||
name=project.name,
|
||||
display_name=project.display_name,
|
||||
description=project.description,
|
||||
created_at=project.created_at,
|
||||
project_id=project.id,
|
||||
owner_id=project.owner_id,
|
||||
owner_type=project.owner_type,
|
||||
quotas=project.quotas,
|
||||
updated_at=project.updated_at,
|
||||
)
|
||||
|
||||
projects = client.projects_service_list_memberships()
|
||||
if len(projects.memberships) == 0:
|
||||
raise ValueError("No valid projects found. Please reach out to lightning.ai team to create a project")
|
||||
if len(projects.memberships) > 1 and verbose:
|
||||
|
|
|
@ -1,29 +1,19 @@
|
|||
import os
|
||||
from unittest import mock
|
||||
|
||||
from lightning_cloud.openapi.models import V1ListMembershipsResponse, V1Membership
|
||||
from lightning_cloud.openapi.models import V1Project
|
||||
|
||||
from lightning.app.utilities.cloud import _get_project, is_running_in_cloud
|
||||
|
||||
|
||||
@mock.patch("lightning.app.core.constants.LIGHTNING_CLOUD_ORGANIZATION_ID", "organization_id")
|
||||
def test_get_project_picks_up_organization_id():
|
||||
"""Uses organization_id from `LIGHTNING_CLOUD_ORGANIZATION_ID` config var if none passed."""
|
||||
def test_get_project_queries_by_project_id_directly_if_it_is_passed():
|
||||
lightning_client = mock.MagicMock()
|
||||
lightning_client.projects_service_list_memberships = mock.MagicMock(
|
||||
return_value=V1ListMembershipsResponse(memberships=[V1Membership(project_id="project_id")]),
|
||||
lightning_client.projects_service_get_project = mock.MagicMock(
|
||||
return_value=V1Project(id="project_id"),
|
||||
)
|
||||
_get_project(lightning_client)
|
||||
lightning_client.projects_service_list_memberships.assert_called_once_with(organization_id="organization_id")
|
||||
|
||||
|
||||
def test_get_project_doesnt_pass_organization_id_if_its_not_set():
|
||||
lightning_client = mock.MagicMock()
|
||||
lightning_client.projects_service_list_memberships = mock.MagicMock(
|
||||
return_value=V1ListMembershipsResponse(memberships=[V1Membership(project_id="project_id")]),
|
||||
)
|
||||
_get_project(lightning_client)
|
||||
lightning_client.projects_service_list_memberships.assert_called_once_with()
|
||||
project = _get_project(lightning_client, project_id="project_id")
|
||||
assert project.project_id == "project_id"
|
||||
lightning_client.projects_service_get_project.assert_called_once_with("project_id")
|
||||
|
||||
|
||||
def test_is_running_cloud():
|
||||
|
|
Loading…
Reference in New Issue