fix: get project (#17666)

This commit is contained in:
Yurij Mikhalevich 2023-05-19 22:29:22 +04:00 committed by GitHub
parent 3a6d0d80c3
commit 61246c3b35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 36 deletions

View File

@ -62,7 +62,6 @@ DISABLE_DEPENDENCY_CACHE = bool(int(os.getenv("DISABLE_DEPENDENCY_CACHE", "0")))
# Project under which the resources need to run in cloud. If this env is not set,
# cloud runner will try to get the default project from the cloud
LIGHTNING_CLOUD_PROJECT_ID = os.getenv("LIGHTNING_CLOUD_PROJECT_ID")
LIGHTNING_CLOUD_ORGANIZATION_ID = os.getenv("LIGHTNING_CLOUD_ORGANIZATION_ID")
LIGHTNING_CLOUD_PRINT_SPECS = os.getenv("LIGHTNING_CLOUD_PRINT_SPECS")
LIGHTNING_DIR = os.getenv("LIGHTNING_DIR", str(Path.home() / ".lightning"))
LIGHTNING_CREDENTIAL_PATH = os.getenv("LIGHTNING_CREDENTIAL_PATH", str(Path(LIGHTNING_DIR) / "credentials.json"))

View File

@ -18,36 +18,35 @@ from typing import Optional
from lightning_cloud.openapi import V1Membership
import lightning.app
from lightning.app.core import constants
from lightning.app.core.constants import LIGHTNING_CLOUD_PROJECT_ID
from lightning.app.utilities.enum import AppStage
from lightning.app.utilities.network import LightningClient
def _get_project(
client: LightningClient,
organization_id: Optional[str] = None,
project_id: Optional[str] = None,
verbose: bool = True,
) -> V1Membership:
def _get_project(client: LightningClient, project_id: Optional[str] = None, verbose: bool = True) -> V1Membership:
"""Get a project membership for the user from the backend."""
if project_id is None:
project_id = constants.LIGHTNING_CLOUD_PROJECT_ID
if organization_id is None:
organization_id = constants.LIGHTNING_CLOUD_ORGANIZATION_ID
project_id = LIGHTNING_CLOUD_PROJECT_ID
projects = client.projects_service_list_memberships(
**({"organization_id": organization_id} if organization_id is not None else {})
)
if project_id is not None:
for membership in projects.memberships:
if membership.project_id == project_id:
break
else:
project = client.projects_service_get_project(project_id)
if not project:
raise ValueError(
"Environment variable `LIGHTNING_CLOUD_PROJECT_ID` is set but could not find an associated project."
)
return membership
return V1Membership(
name=project.name,
display_name=project.display_name,
description=project.description,
created_at=project.created_at,
project_id=project.id,
owner_id=project.owner_id,
owner_type=project.owner_type,
quotas=project.quotas,
updated_at=project.updated_at,
)
projects = client.projects_service_list_memberships()
if len(projects.memberships) == 0:
raise ValueError("No valid projects found. Please reach out to lightning.ai team to create a project")
if len(projects.memberships) > 1 and verbose:

View File

@ -1,29 +1,19 @@
import os
from unittest import mock
from lightning_cloud.openapi.models import V1ListMembershipsResponse, V1Membership
from lightning_cloud.openapi.models import V1Project
from lightning.app.utilities.cloud import _get_project, is_running_in_cloud
@mock.patch("lightning.app.core.constants.LIGHTNING_CLOUD_ORGANIZATION_ID", "organization_id")
def test_get_project_picks_up_organization_id():
"""Uses organization_id from `LIGHTNING_CLOUD_ORGANIZATION_ID` config var if none passed."""
def test_get_project_queries_by_project_id_directly_if_it_is_passed():
lightning_client = mock.MagicMock()
lightning_client.projects_service_list_memberships = mock.MagicMock(
return_value=V1ListMembershipsResponse(memberships=[V1Membership(project_id="project_id")]),
lightning_client.projects_service_get_project = mock.MagicMock(
return_value=V1Project(id="project_id"),
)
_get_project(lightning_client)
lightning_client.projects_service_list_memberships.assert_called_once_with(organization_id="organization_id")
def test_get_project_doesnt_pass_organization_id_if_its_not_set():
lightning_client = mock.MagicMock()
lightning_client.projects_service_list_memberships = mock.MagicMock(
return_value=V1ListMembershipsResponse(memberships=[V1Membership(project_id="project_id")]),
)
_get_project(lightning_client)
lightning_client.projects_service_list_memberships.assert_called_once_with()
project = _get_project(lightning_client, project_id="project_id")
assert project.project_id == "project_id"
lightning_client.projects_service_get_project.assert_called_once_with("project_id")
def test_is_running_cloud():