# lightning/tests/tests_store/test_store.py (84 lines, 2.8 KiB, Python)

import os
from unittest import mock
from lightning.store import download_model, list_models, upload_model
from lightning_cloud.openapi import (
V1DownloadModelResponse,
V1GetUserResponse,
V1ListMembershipsResponse,
V1ListModelsResponse,
V1Membership,
V1Model,
V1Project,
V1UploadModelRequest,
V1UploadModelResponse,
)
@mock.patch("lightning.store.store._Client")
@mock.patch("lightning.store.store._upload_file_to_url")
def test_upload_model(mock_upload_file_to_url, mock_client):
    """Check that ``upload_model`` registers the model and uploads the checkpoint.

    ``_Client`` and ``_upload_file_to_url`` are patched. The test asserts three things:
    - the current user is looked up exactly once;
    - the upload request carries the ``<username>/<model>`` name, the version and the
      project id;
    - the checkpoint path is handed to the upload helper together with the URL
      returned by the service.
    """
    # Configure the instance that the patched ``_Client`` class returns when called.
    client = mock_client.return_value
    client.auth_service_get_user.return_value = V1GetUserResponse(username="test-username")

    # The project id may be resolved through either of these endpoints, so stub both.
    project_id = "test-project-id"
    client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
        memberships=[V1Membership(project_id=project_id)],
    )
    client.projects_service_get_project.return_value = V1Project(id=project_id)

    upload_url = "https://test"
    client.models_store_upload_model.return_value = V1UploadModelResponse(upload_url=upload_url)

    upload_model("test-model", "test.ckpt", version="0.0.1")

    client.auth_service_get_user.assert_called_once()
    expected_request = V1UploadModelRequest(
        name="test-username/test-model",
        version="0.0.1",
        project_id=project_id,
    )
    client.models_store_upload_model.assert_called_once_with(expected_request)
    mock_upload_file_to_url.assert_called_once_with(upload_url, "test.ckpt", progress_bar=True)
@mock.patch("lightning.store.store._Client")
@mock.patch("lightning.store.store._download_file_from_url")
def test_download_model(mock_download_file_from_url, mock_client):
    """Check that ``download_model`` resolves the download URL and saves to an absolute path.

    The store is queried by model name and version. The patched download helper is
    then expected to receive the returned URL and the absolute form of the local path
    passed by the caller.
    """
    # Configure the instance that the patched ``_Client`` class returns when called.
    client = mock_client.return_value
    download_url = "https://test"
    client.models_store_download_model.return_value = V1DownloadModelResponse(download_url=download_url)

    model_name, model_version = "test-username/test-model", "0.0.1"
    download_model(model_name, "test.ckpt", version=model_version)

    client.models_store_download_model.assert_called_once_with(name=model_name, version=model_version)
    # The relative "test.ckpt" is expected to be expanded to an absolute path.
    expected_path = os.path.abspath("test.ckpt")
    mock_download_file_from_url.assert_called_once_with(download_url, expected_path, progress_bar=True)
@mock.patch("lightning.store.store._Client")
def test_list_models(mock_client):
    """Check that ``list_models`` queries the resolved project and returns its models."""
    # Configure the instance that the patched ``_Client`` class returns when called.
    client = mock_client.return_value

    # The project id may be resolved through either of these endpoints, so stub both.
    project_id = "test-project-id"
    client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
        memberships=[V1Membership(project_id=project_id)],
    )
    client.projects_service_get_project.return_value = V1Project(id=project_id)
    client.models_store_list_models.return_value = V1ListModelsResponse(models=[V1Model(name="test-model")])

    models = list_models()

    assert models[0].name == "test-model"
    client.models_store_list_models.assert_called_once_with(project_id=project_id)