lightning/tests/tests_cloud/test_versioning.py

62 lines
2.0 KiB
Python
Raw Normal View History

import os
import platform
import pytest
from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME
from tests_cloud.helpers import cleanup
from lightning.store.cloud_api import download_from_cloud, upload_to_cloud
from lightning.store.save import _LIGHTNING_STORAGE_DIR
from pytorch_lightning.demos.boring_classes import BoringModel
def assert_download_successful(username, model_name, version):
    """Assert that a downloaded model version exists locally and is non-empty.

    The download is looked up under
    ``<_LIGHTNING_STORAGE_DIR>/<username>/<model_name>/<version>``.

    Args:
        username: Model owner; used as the first path component below the storage dir.
        model_name: Name of the downloaded model.
        version: Local version folder name (e.g. ``"version_1_0_0"``), not the raw
            version string passed to upload/download.

    Raises:
        AssertionError: If the folder does not exist or contains no entries.
    """
    folder_name = os.path.join(_LIGHTNING_STORAGE_DIR, username, model_name, version)
    assert os.path.isdir(folder_name), f"Folder name: {folder_name} doesn't exist."
    # An existing but empty folder would mean nothing was saved for this version;
    # report the path so the failure is diagnosable, like the check above.
    assert os.listdir(folder_name), f"Folder name: {folder_name} is empty."
@pytest.mark.parametrize(
    ("case", "expected_case"),
    [
        ("1.0.0", "version_1_0_0"),
        ("0.0.1", "version_0_0_1"),
        ("latest", "version_latest"),
        ("1.0", "version_1_0"),
        ("1", "version_1"),
        ("0.1", "version_0_1"),
        ("", "version_latest"),
    ],
)
def test_versioning_valid_case(case, expected_case, model_name: str = "boring_model_versioning"):
    """Round-trip a model through the cloud store using a valid version string.

    ``case`` is the version passed to both upload and download; ``expected_case``
    is the local folder name the download should end up in (an empty version maps
    to ``version_latest``). ``model_name`` carries a default, so pytest does not
    request it as a fixture.
    """
    cleanup()
    model = BoringModel()
    upload_to_cloud(
        model_name,
        version=case,
        model=model,
        api_key=_API_KEY,
        project_id=_PROJECT_ID,
    )
    model_id = f"{_USERNAME}/{model_name}"
    download_from_cloud(model_id, version=case)
    assert_download_successful(_USERNAME, model_name, expected_case)
@pytest.mark.parametrize(
    "case",
    [
        " version with spaces ",
        "*",
        # Kept as a skipped param (instead of commented-out code) so the pending
        # case stays visible in test reports until it is re-enabled.
        pytest.param("#", marks=pytest.mark.skip(reason="TODO: '#' case temporarily disabled; add it back later")),
        "¡",
        "©",
    ],
)
def test_versioning_invalid_case(case, model_name: str = "boring_model_versioning"):
    """Check that invalid version strings are rejected on upload and on download.

    Upload is expected to raise ``ConnectionRefusedError``. Download is expected to
    raise the same error, except for ``"*"`` on Windows where ``OSError`` is expected
    instead (presumably because ``*`` is not a legal character in Windows paths).
    ``model_name`` carries a default, so pytest does not request it as a fixture.
    """
    cleanup()
    with pytest.raises(ConnectionRefusedError):
        upload_to_cloud(model_name, version=case, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)
    # NOTE: ConnectionRefusedError is a subclass of OSError, so the Windows "*"
    # branch accepts either exception.
    error = OSError if case == "*" and platform.system() == "Windows" else ConnectionRefusedError
    with pytest.raises(error):
        download_from_cloud(f"{_USERNAME}/{model_name}", version=case)