2023-08-08 08:34:03 +00:00
|
|
|
import os
|
|
|
|
from unittest import mock
|
|
|
|
|
ruff: replace isort with ruff +TPU (#17684)
* ruff: replace isort with ruff
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing & imports
* lines in warning test
* docs
* fix enum import
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing
* import
* fix lines
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* type ClusterEnvironment
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
|
|
|
from lightning.store import download_model, list_models, upload_model
|
2023-08-08 08:34:03 +00:00
|
|
|
from lightning_cloud.openapi import (
|
|
|
|
V1DownloadModelResponse,
|
|
|
|
V1GetUserResponse,
|
|
|
|
V1ListMembershipsResponse,
|
|
|
|
V1ListModelsResponse,
|
|
|
|
V1Membership,
|
|
|
|
V1Model,
|
|
|
|
V1Project,
|
|
|
|
V1UploadModelRequest,
|
|
|
|
V1UploadModelResponse,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch("lightning.store.store._Client")
|
|
|
|
@mock.patch("lightning.store.store._upload_file_to_url")
|
|
|
|
def test_upload_model(mock_upload_file_to_url, mock_client):
|
|
|
|
mock_client = mock_client()
|
|
|
|
|
|
|
|
mock_client.auth_service_get_user.return_value = V1GetUserResponse(username="test-username")
|
|
|
|
|
|
|
|
# either one of these project APIs could be called
|
|
|
|
mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
|
|
|
|
memberships=[V1Membership(project_id="test-project-id")],
|
|
|
|
)
|
|
|
|
mock_client.projects_service_get_project.return_value = V1Project(id="test-project-id")
|
|
|
|
|
|
|
|
mock_client.models_store_upload_model.return_value = V1UploadModelResponse(
|
|
|
|
upload_url="https://test",
|
|
|
|
)
|
|
|
|
|
|
|
|
upload_model("test-model", "test.ckpt", version="0.0.1")
|
|
|
|
|
|
|
|
mock_client.auth_service_get_user.assert_called_once()
|
|
|
|
mock_client.models_store_upload_model.assert_called_once_with(
|
|
|
|
V1UploadModelRequest(
|
|
|
|
name="test-username/test-model",
|
|
|
|
version="0.0.1",
|
|
|
|
project_id="test-project-id",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
mock_upload_file_to_url.assert_called_once_with("https://test", "test.ckpt", progress_bar=True)
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch("lightning.store.store._Client")
|
|
|
|
@mock.patch("lightning.store.store._download_file_from_url")
|
|
|
|
def test_download_model(mock_download_file_from_url, mock_client):
|
|
|
|
mock_client = mock_client()
|
|
|
|
|
|
|
|
mock_client.models_store_download_model.return_value = V1DownloadModelResponse(
|
|
|
|
download_url="https://test",
|
|
|
|
)
|
|
|
|
|
|
|
|
download_model("test-username/test-model", "test.ckpt", version="0.0.1")
|
|
|
|
|
|
|
|
mock_client.models_store_download_model.assert_called_once_with(
|
|
|
|
name="test-username/test-model",
|
|
|
|
version="0.0.1",
|
|
|
|
)
|
|
|
|
|
|
|
|
mock_download_file_from_url.assert_called_once_with("https://test", os.path.abspath("test.ckpt"), progress_bar=True)
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch("lightning.store.store._Client")
|
|
|
|
def test_list_models(mock_client):
|
|
|
|
mock_client = mock_client()
|
|
|
|
|
|
|
|
# either one of these project APIs could be called
|
|
|
|
mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
|
|
|
|
memberships=[V1Membership(project_id="test-project-id")],
|
|
|
|
)
|
|
|
|
mock_client.projects_service_get_project.return_value = V1Project(id="test-project-id")
|
|
|
|
|
|
|
|
mock_client.models_store_list_models.return_value = V1ListModelsResponse(models=[V1Model(name="test-model")])
|
|
|
|
|
|
|
|
res = list_models()
|
|
|
|
assert res[0].name == "test-model"
|
|
|
|
|
|
|
|
mock_client.models_store_list_models.assert_called_once_with(project_id="test-project-id")
|