lightning/tests/tests_cloud/test_source_code.py

101 lines
3.1 KiB
Python

import inspect
import os
import tempfile
from tests_cloud import _API_KEY, _PROJECT_ID, _PROJECT_ROOT, _TEST_ROOT, _USERNAME
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.store import download_model, upload_model
from lightning.store.save import __STORAGE_DIR_NAME
def test_source_code_implicit(lit_home, model_name: str = "model_test_source_code_implicit"):
    """Round-trip a model that is uploaded without an explicit source path.

    After the upload/download pair, a file with the same base name as the
    module defining ``BoringModel`` must exist in the model's ``latest``
    folder below ``lit_home``.
    """
    upload_model(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)
    download_model(f"{_USERNAME}/{model_name}")

    # Only the base name of the defining module is compared, not its full path.
    source_name = os.path.basename(inspect.getsourcefile(BoringModel))
    latest_dir = os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, "latest")
    assert os.path.isfile(os.path.join(latest_dir, source_name))
def test_source_code_saving_disabled(lit_home, model_name: str = "model_test_source_code_dont_save"):
    """Round-trip a model that is uploaded with ``save_code=False``.

    After the upload/download pair, no file named like the module defining
    ``BoringModel`` may exist in the model's ``latest`` folder below
    ``lit_home``.
    """
    upload_model(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID, save_code=False)
    download_model(f"{_USERNAME}/{model_name}")

    source_name = os.path.basename(inspect.getsourcefile(BoringModel))
    latest_dir = os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, "latest")
    unexpected_file = os.path.join(latest_dir, source_name)
    assert not os.path.isfile(unexpected_file)
def test_source_code_explicit_relative_folder(lit_home, model_name: str = "model_test_source_code_explicit_relative"):
    """Round-trip a model uploaded with ``_TEST_ROOT`` as ``source_code_path``.

    After the upload/download pair, a directory named like the last component
    of the absolute form of ``_TEST_ROOT`` must exist in the model's
    ``latest`` folder below ``lit_home``.
    """
    upload_model(model_name, model=BoringModel(), source_code_path=_TEST_ROOT, api_key=_API_KEY, project_id=_PROJECT_ID)
    download_model(f"{_USERNAME}/{model_name}")

    # Resolve to an absolute path first so the folder name is taken from the real location.
    folder_name = os.path.basename(os.path.abspath(_TEST_ROOT))
    latest_dir = os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, "latest")
    assert os.path.isdir(os.path.join(latest_dir, folder_name))
def test_source_code_explicit_absolute_folder(
    lit_home, model_name: str = "model_test_source_code_explicit_absolute_path"
):
    """Round-trip a model uploaded with an absolute temporary directory as ``source_code_path``.

    After the upload/download pair, a directory with the same base name as
    the temporary directory must exist in the model's ``latest`` folder below
    ``lit_home``.
    """
    # TODO: unify with above `test_source_code_explicit_relative_folder`
    with tempfile.TemporaryDirectory() as tmpdir:
        upload_dir = os.path.abspath(tmpdir)
        upload_model(
            model_name, model=BoringModel(), source_code_path=upload_dir, api_key=_API_KEY, project_id=_PROJECT_ID
        )

    # NOTE(review): the download is placed after the temporary directory is cleaned up;
    # confirm this matches the intended extent of the ``with`` block.
    download_model(f"{_USERNAME}/{model_name}")

    # ``upload_dir`` is already absolute, so its base name can be taken directly.
    folder_name = os.path.basename(upload_dir)
    latest_dir = os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, "latest")
    assert os.path.isdir(os.path.join(latest_dir, folder_name))
def test_source_code_explicit_file(lit_home, model_name: str = "model_test_source_code_explicit_file"):
    """Round-trip a model uploaded with a single file as ``source_code_path``.

    The file passed is ``setup.py`` under ``_PROJECT_ROOT``. After the
    upload/download pair, a file with that base name must exist in the
    model's ``latest`` folder below ``lit_home``.
    """
    file_name = os.path.join(_PROJECT_ROOT, "setup.py")
    upload_model(model_name, model=BoringModel(), source_code_path=file_name, api_key=_API_KEY, project_id=_PROJECT_ID)
    download_model(f"{_USERNAME}/{model_name}")

    latest_dir = os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, "latest")
    expected_file = os.path.join(latest_dir, os.path.basename(file_name))
    assert os.path.isfile(expected_file)