[App] Content for plugins (#17243)

Co-authored-by: Yurij Mikhalevich <yurij@grid.ai>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
Ethan Harris 2023-07-07 11:05:58 +01:00 committed by GitHub
parent c8656f1a27
commit 2c3dfc0fb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 126 additions and 81 deletions

View File

@ -61,6 +61,8 @@ class LightningPlugin:
"""
from lightning.app.runners.cloud import CloudRuntime
logger.info(f"Processing job run request. name: {name}, app_entrypoint: {app_entrypoint}, env_vars: {env_vars}")
# Dispatch the job
_set_flow_context()
@ -123,6 +125,8 @@ def _run_plugin(run: _Run) -> Dict[str, Any]:
# Download the tarball
try:
logger.info(f"Downloading plugin source: {run.source_code_url}")
# Sometimes the URL gets encoded, so we parse it here
source_code_url = urlparse(run.source_code_url).geturl()
@ -141,6 +145,8 @@ def _run_plugin(run: _Run) -> Dict[str, Any]:
# Extract
try:
logger.info("Extracting plugin source.")
with tarfile.open(download_path, "r:gz") as tf:
tf.extractall(source_path)
except Exception as ex:
@ -151,6 +157,8 @@ def _run_plugin(run: _Run) -> Dict[str, Any]:
# Import the plugin
try:
logger.info(f"Importing plugin: {run.plugin_entrypoint}")
plugin = _load_plugin_from_file(os.path.join(source_path, run.plugin_entrypoint))
except Exception as ex:
raise HTTPException(
@ -163,6 +171,11 @@ def _run_plugin(run: _Run) -> Dict[str, Any]:
# Setup and run the plugin
try:
logger.info(
"Running plugin. "
f"project_id: {run.project_id}, cloudspace_id: {run.cloudspace_id}, cluster_id: {run.cluster_id}."
)
plugin._setup(
project_id=run.project_id,
cloudspace_id=run.cloudspace_id,

View File

@ -218,16 +218,22 @@ class CloudRuntime(Runtime):
# Dispatch in four phases: resolution, validation, spec creation, API transactions
# Resolution
root = self._resolve_root()
repo = self._resolve_repo(root)
# If the root will already be there, we don't need to upload and preserve the absolute entrypoint
absolute_entrypoint = str(root).startswith("/project")
# If system customization files found, it will set their location path
sys_customizations_root = self._resolve_env_root()
repo = self._resolve_repo(
root,
default_ignore=False,
package_source=not absolute_entrypoint,
sys_customizations_root=sys_customizations_root,
)
project = self._resolve_project(project_id=project_id)
existing_instances = self._resolve_run_instances_by_name(project_id, name)
name = self._resolve_run_name(name, existing_instances)
cloudspace = self._resolve_cloudspace(project_id, cloudspace_id)
queue_server_type = self._resolve_queue_server_type()
# If system customization files found, it will set their location path
sys_customizations_sync_root = self._resolve_env_root()
self.app._update_index_file()
# Validation
@ -241,17 +247,26 @@ class CloudRuntime(Runtime):
flow_servers = self._get_flow_servers()
network_configs = self._get_network_configs(flow_servers)
works = self._get_works(cloudspace=cloudspace)
run_body = self._get_run_body(cluster_id, flow_servers, network_configs, works, False, root, True)
run_body = self._get_run_body(
cluster_id,
flow_servers,
network_configs,
works,
False,
root,
True,
True,
absolute_entrypoint,
)
env_vars = self._get_env_vars(self.env_vars, self.secrets, self.run_app_comment_commands)
# If the system customization root is set, prepare files for environment synchronization
if sys_customizations_sync_root is not None:
repo.prepare_sys_customizations_sync(sys_customizations_sync_root)
# API transactions
logger.info(f"Creating cloudspace run. run_body: {run_body}")
run = self._api_create_run(project_id, cloudspace_id, run_body)
self._api_package_and_upload_repo(repo, run)
logger.info(f"Creating cloudspace run instance. name: {name}")
run_instance = self._api_create_run_instance(
cluster_id,
project_id,
@ -454,6 +469,9 @@ class CloudRuntime(Runtime):
self,
root: Path,
ignore_functions: Optional[List[_IGNORE_FUNCTION]] = None,
default_ignore: bool = True,
package_source: bool = True,
sys_customizations_root: Optional[Path] = None,
) -> LocalSourceCodeDir:
"""Gather and merge all lightningignores from the app children and create the ``LocalSourceCodeDir``
object."""
@ -470,7 +488,13 @@ class CloudRuntime(Runtime):
patterns = _parse_lightningignore(merged)
ignore_functions = [*ignore_functions, partial(_filter_ignored, root, patterns)]
return LocalSourceCodeDir(path=root, ignore_functions=ignore_functions)
return LocalSourceCodeDir(
path=root,
ignore_functions=ignore_functions,
default_ignore=default_ignore,
package_source=package_source,
sys_customizations_root=sys_customizations_root,
)
def _resolve_project(self, project_id: Optional[str] = None) -> V1Membership:
"""Determine the project to run on, choosing a default if multiple projects are found."""
@ -788,7 +812,7 @@ class CloudRuntime(Runtime):
network_config=[V1NetworkConfig(name=random_name, port=work.port)],
data_connection_mounts=data_connection_mounts,
)
works.append(V1Work(name=work.name, spec=work_spec))
works.append(V1Work(name=work.name, display_name=work.display_name, spec=work_spec))
return works
@ -801,12 +825,18 @@ class CloudRuntime(Runtime):
no_cache: bool,
root: Path,
start_server: bool,
should_mount_cloudspace_content: bool = False,
absolute_entrypoint: bool = False,
) -> CloudspaceIdRunsBody:
"""Get the specification of the run creation request."""
# The entry point file needs to be relative to the root of the uploaded source file directory,
# because the backend will invoke the lightning commands relative said source directory
# TODO: we shouldn't set this if the entrypoint isn't a file but the backend gives an error if we don't
app_entrypoint_file = Path(self.entrypoint).absolute().relative_to(root)
if absolute_entrypoint:
# If the entrypoint will already exist in the cloud then we can choose to keep it as an absolute path.
app_entrypoint_file = Path(self.entrypoint).absolute()
else:
# The entry point file needs to be relative to the root of the uploaded source file directory,
# because the backend will invoke the lightning commands relative said source directory
# TODO: we shouldn't set this if the entrypoint isn't a file but the backend gives an error if we don't
app_entrypoint_file = Path(self.entrypoint).absolute().relative_to(root)
run_body = CloudspaceIdRunsBody(
cluster_id=cluster_id,
@ -816,6 +846,7 @@ class CloudRuntime(Runtime):
network_config=network_configs,
works=works,
local_source=True,
should_mount_cloudspace_content=should_mount_cloudspace_content,
)
if self.app is not None:
@ -830,9 +861,10 @@ class CloudRuntime(Runtime):
# if requirements file at the root of the repository is present,
# we pass just the file name to the backend, so backend can find it in the relative path
requirements_file = root / "requirements.txt"
if requirements_file.is_file():
if requirements_file.is_file() and requirements_file.exists():
requirements_path = requirements_file if absolute_entrypoint else "requirements.txt"
run_body.image_spec = Gridv1ImageSpec(
dependency_file_info=V1DependencyFileInfo(package_manager=V1PackageManager.PIP, path="requirements.txt")
dependency_file_info=V1DependencyFileInfo(package_manager=V1PackageManager.PIP, path=requirements_path)
)
if not DISABLE_DEPENDENCY_CACHE and not no_cache:
# hash used for caching the dependencies
@ -1000,7 +1032,10 @@ class CloudRuntime(Runtime):
)
@staticmethod
def _api_package_and_upload_repo(repo: LocalSourceCodeDir, run: V1LightningRun) -> None:
def _api_package_and_upload_repo(
repo: LocalSourceCodeDir,
run: V1LightningRun,
) -> None:
"""Package and upload the provided local source code directory to the provided run."""
if run.source_upload_url == "":
raise RuntimeError("The source upload url is empty.")

View File

@ -13,6 +13,7 @@
# limitations under the License.
import os
import uuid
from contextlib import contextmanager
from pathlib import Path
from shutil import copytree, rmtree
@ -20,7 +21,6 @@ from typing import List, Optional
from lightning.app.core.constants import DOT_IGNORE_FILENAME, SYS_CUSTOMIZATIONS_SYNC_PATH
from lightning.app.source_code.copytree import _copytree, _IGNORE_FUNCTION
from lightning.app.source_code.hashing import _get_hash
from lightning.app.source_code.tar import _tar_path
from lightning.app.source_code.uploader import FileUploader
@ -28,7 +28,14 @@ from lightning.app.source_code.uploader import FileUploader
class LocalSourceCodeDir:
"""Represents the source code directory and provide the utilities to manage it."""
def __init__(self, path: Path, ignore_functions: Optional[List[_IGNORE_FUNCTION]] = None) -> None:
def __init__(
self,
path: Path,
ignore_functions: Optional[List[_IGNORE_FUNCTION]] = None,
default_ignore: bool = True,
package_source: bool = True,
sys_customizations_root: Optional[Path] = None,
) -> None:
if "LIGHTNING_VSCODE_WORKSPACE" in os.environ:
# Don't use home to store the tar ball. This won't play nice with symlinks
self.cache_location: Path = Path("/tmp", ".lightning", "cache", "repositories")
@ -37,8 +44,10 @@ class LocalSourceCodeDir:
self.path = path
self.ignore_functions = ignore_functions
self.package_source = package_source
self.sys_customizations_root = sys_customizations_root
# cache checksum version
# cache version
self._version: Optional[str] = None
self._non_ignored_files: Optional[List[str]] = None
@ -46,8 +55,8 @@ class LocalSourceCodeDir:
if not self.cache_location.exists():
self.cache_location.mkdir(parents=True, exist_ok=True)
# Create a default dotignore if it doesn't exist
if not (path / DOT_IGNORE_FILENAME).is_file():
# Create a default dotignore if requested and it doesn't exist
if default_ignore and not (path / DOT_IGNORE_FILENAME).is_file():
with open(path / DOT_IGNORE_FILENAME, "w") as f:
f.write("venv/\n")
if (path / "bin" / "activate").is_file() or (path / "pyvenv.cfg").is_file():
@ -61,7 +70,10 @@ class LocalSourceCodeDir:
def files(self) -> List[str]:
"""Returns a set of files that are not ignored by .lightningignore."""
if self._non_ignored_files is None:
self._non_ignored_files = _copytree(self.path, "", ignore_functions=self.ignore_functions, dry_run=True)
if self.package_source:
self._non_ignored_files = _copytree(self.path, "", ignore_functions=self.ignore_functions, dry_run=True)
else:
self._non_ignored_files = []
return self._non_ignored_files
@property
@ -71,8 +83,8 @@ class LocalSourceCodeDir:
if self._version is not None:
return self._version
# stores both version and a set with the files used to generate the checksum
self._version = _get_hash(files=self.files, algorithm="blake2")
# create a random version ID and store it
self._version = uuid.uuid4().hex
return self._version
@property
@ -87,7 +99,11 @@ class LocalSourceCodeDir:
session_path = self.cache_location / "packaging_sessions" / self.version
try:
rmtree(session_path, ignore_errors=True)
_copytree(self.path, session_path, ignore_functions=self.ignore_functions)
if self.package_source:
_copytree(self.path, session_path, ignore_functions=self.ignore_functions)
if self.sys_customizations_root is not None:
path_to_sync = Path(session_path, SYS_CUSTOMIZATIONS_SYNC_PATH)
copytree(self.sys_customizations_root, path_to_sync, dirs_exist_ok=True)
yield session_path
finally:
rmtree(session_path, ignore_errors=True)
@ -108,12 +124,6 @@ class LocalSourceCodeDir:
_tar_path(source_path=session_path, target_file=str(self.package_path), compression=True)
return self.package_path
def prepare_sys_customizations_sync(self, sys_customizations_root: Path) -> None:
"""Prepares files for system environment customization setup by copying conda and system environment files
to an app files directory."""
path_to_sync = Path(self.path, SYS_CUSTOMIZATIONS_SYNC_PATH)
copytree(sys_customizations_root, path_to_sync, dirs_exist_ok=True)
def upload(self, url: str) -> None:
"""Uploads package to URL, usually pre-signed URL.

View File

@ -98,6 +98,7 @@ def get_cloud_runtime_request_body(**kwargs) -> "CloudspaceIdRunsBody":
"app_entrypoint_file": mock.ANY,
"enable_app_server": True,
"is_headless": True,
"should_mount_cloudspace_content": False,
"flow_servers": [],
"image_spec": None,
"works": [],
@ -386,6 +387,7 @@ class TestAppCreationClient:
app_entrypoint_file=mock.ANY,
enable_app_server=True,
is_headless=False,
should_mount_cloudspace_content=False,
flow_servers=[],
image_spec=None,
works=[],
@ -433,6 +435,7 @@ class TestAppCreationClient:
app_entrypoint_file=mock.ANY,
enable_app_server=True,
is_headless=False,
should_mount_cloudspace_content=False,
flow_servers=[],
image_spec=None,
works=[],
@ -491,6 +494,7 @@ class TestAppCreationClient:
app_entrypoint_file=mock.ANY,
enable_app_server=True,
is_headless=False,
should_mount_cloudspace_content=False,
flow_servers=[],
image_spec=None,
works=[],
@ -624,6 +628,7 @@ class TestAppCreationClient:
app_entrypoint_file="entrypoint.py",
enable_app_server=True,
is_headless=False,
should_mount_cloudspace_content=False,
flow_servers=[],
dependency_cache_key=get_hash(requirements_file),
user_requested_flow_compute_config=mock.ANY,
@ -639,6 +644,7 @@ class TestAppCreationClient:
expected_body.works = [
V1Work(
name="test-work",
display_name="",
spec=V1LightningworkSpec(
build_spec=V1BuildSpec(
commands=["echo 'start'"],
@ -813,6 +819,7 @@ class TestAppCreationClient:
app_entrypoint_file="entrypoint.py",
enable_app_server=True,
is_headless=False,
should_mount_cloudspace_content=False,
flow_servers=[],
dependency_cache_key=get_hash(requirements_file),
user_requested_flow_compute_config=mock.ANY,
@ -825,6 +832,7 @@ class TestAppCreationClient:
works=[
V1Work(
name="test-work",
display_name="",
spec=V1LightningworkSpec(
build_spec=V1BuildSpec(
commands=["echo 'start'"],
@ -942,6 +950,7 @@ class TestAppCreationClient:
app_entrypoint_file="entrypoint.py",
enable_app_server=True,
is_headless=False,
should_mount_cloudspace_content=False,
flow_servers=[],
dependency_cache_key=get_hash(requirements_file),
user_requested_flow_compute_config=mock.ANY,
@ -954,6 +963,7 @@ class TestAppCreationClient:
works=[
V1Work(
name="test-work",
display_name="",
spec=V1LightningworkSpec(
build_spec=V1BuildSpec(
commands=["echo 'start'"],
@ -1112,6 +1122,7 @@ class TestAppCreationClient:
app_entrypoint_file="entrypoint.py",
enable_app_server=True,
is_headless=False,
should_mount_cloudspace_content=False,
flow_servers=[],
dependency_cache_key=get_hash(requirements_file),
user_requested_flow_compute_config=mock.ANY,
@ -1124,6 +1135,7 @@ class TestAppCreationClient:
works=[
V1Work(
name="test-work",
display_name="",
spec=V1LightningworkSpec(
build_spec=V1BuildSpec(
commands=["echo 'start'"],
@ -1153,6 +1165,7 @@ class TestAppCreationClient:
app_entrypoint_file="entrypoint.py",
enable_app_server=True,
is_headless=False,
should_mount_cloudspace_content=False,
flow_servers=[],
dependency_cache_key=get_hash(requirements_file),
user_requested_flow_compute_config=mock.ANY,
@ -1165,6 +1178,7 @@ class TestAppCreationClient:
works=[
V1Work(
name="test-work",
display_name="",
spec=V1LightningworkSpec(
build_spec=V1BuildSpec(
commands=["echo 'start'"],
@ -1300,6 +1314,7 @@ class TestAppCreationClient:
app_entrypoint_file="entrypoint.py",
enable_app_server=True,
is_headless=False,
should_mount_cloudspace_content=False,
flow_servers=[],
dependency_cache_key=get_hash(requirements_file),
image_spec=Gridv1ImageSpec(
@ -1312,6 +1327,7 @@ class TestAppCreationClient:
works=[
V1Work(
name="test-work",
display_name="",
spec=V1LightningworkSpec(
build_spec=V1BuildSpec(
commands=["echo 'start'"],
@ -1616,9 +1632,6 @@ class TestCloudspaceDispatch:
project_id="project_id", id="cloudspace_id"
)
if custom_env_sync_path_value is not None:
mock_repo.prepare_sys_customizations_sync.assert_called_once_with(custom_env_sync_path_value)
mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
project_id="project_id", cloudspace_id="cloudspace_id", body=mock.ANY
)
@ -1791,6 +1804,7 @@ def test_load_app_from_file():
"web",
[
{
"displayName": "",
"name": "root.work",
"spec": {
"buildSpec": {
@ -1815,6 +1829,7 @@ def test_load_app_from_file():
"gallery",
[
{
"display_name": "",
"name": "root.work",
"spec": {
"build_spec": {

View File

@ -11,26 +11,15 @@ from lightning.app.source_code import LocalSourceCodeDir
def test_repository_checksum(tmp_path):
"""LocalRepository.checksum() generates a hash of local dir."""
"""LocalRepository.version() generates a different version each time."""
repository = LocalSourceCodeDir(path=Path(tmp_path))
version_a = repository.version
test_path = tmp_path / "test.txt"
version_a = str(uuid.uuid4())
test_path.write_text(version_a)
checksum_a = repository.version
# file contents don't change; checksum is the same
# version is different
repository = LocalSourceCodeDir(path=Path(tmp_path))
test_path.write_text(version_a)
checksum_b = repository.version
assert checksum_a == checksum_b
version_b = repository.version
# file contents change; checksum is different
repository = LocalSourceCodeDir(path=Path(tmp_path))
test_path.write_text(str(uuid.uuid4()))
checksum_c = repository.version
assert checksum_a != checksum_c
assert version_a != version_b
@pytest.mark.skipif(sys.platform == "win32", reason="this runs only on linux")
@ -48,7 +37,7 @@ def test_local_cache_path_home(tmp_path):
def test_repository_package(tmp_path, monkeypatch):
"""LocalRepository.package() ceates package from local dir."""
"""LocalRepository.package() creates package from local dir."""
cache_path = Path(tmp_path)
source_path = cache_path / "nested"
source_path.mkdir(parents=True, exist_ok=True)
@ -73,23 +62,19 @@ def test_repository_lightningignore(tmp_path):
"""
(tmp_path / ".lightningignore").write_text(lightningignore)
(tmp_path / "test.txt").write_text("test")
# write some data to file and check version
(tmp_path / "test.txt").write_text(str(uuid.uuid4()))
# create repo object
repository = LocalSourceCodeDir(path=Path(tmp_path))
checksum_a = repository.version
assert repository.files == [str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")]
# write file that needs to be ignored
(tmp_path / "ignore").mkdir()
(tmp_path / "ignore/test.txt").write_text(str(uuid.uuid4()))
# check that version remains the same
repository = LocalSourceCodeDir(path=Path(tmp_path))
checksum_b = repository.version
assert checksum_a == checksum_b
assert repository.files == [str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")]
def test_repository_filters_with_absolute_relative_path(tmp_path):
@ -100,16 +85,11 @@ def test_repository_filters_with_absolute_relative_path(tmp_path):
/ignore_dir
"""
(tmp_path / ".lightningignore").write_text(lightningignore)
(tmp_path / "test.txt").write_text("test")
# write some data to file and check version
(tmp_path / "test.txt").write_text(str(uuid.uuid4()))
# create repo object
repository = LocalSourceCodeDir(path=Path(tmp_path))
checksum_a = repository.version
# only two files in hash
assert len(repository._non_ignored_files) == 2
assert repository.files == [str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")]
# write file that needs to be ignored
(tmp_path / "ignore_file").mkdir()
@ -117,14 +97,9 @@ def test_repository_filters_with_absolute_relative_path(tmp_path):
(tmp_path / "ignore_file/test.txt").write_text(str(uuid.uuid4()))
(tmp_path / "ignore_dir/test.txt").write_text(str(uuid.uuid4()))
# check that version remains the same
repository = LocalSourceCodeDir(path=Path(tmp_path))
checksum_b = repository.version
# still only two files in hash
assert len(repository._non_ignored_files) == 2
assert checksum_a == checksum_b
assert repository.files == [str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")]
def test_repository_lightningignore_supports_different_patterns(tmp_path):
@ -269,13 +244,11 @@ def test_repository_lightningignore_supports_different_patterns(tmp_path):
"""
(tmp_path / ".lightningignore").write_text(lightningignore)
(tmp_path / "test.txt").write_text("test")
# write some data to file and check version
(tmp_path / "test.txt").write_text(str(uuid.uuid4()))
# create repo object
repository = LocalSourceCodeDir(path=Path(tmp_path))
checksum_a = repository.version
assert repository.files == [str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")]
# write file that needs to be ignored
(tmp_path / "ignore").mkdir()
@ -283,9 +256,8 @@ def test_repository_lightningignore_supports_different_patterns(tmp_path):
# check that version remains the same
repository = LocalSourceCodeDir(path=Path(tmp_path))
checksum_b = repository.version
assert checksum_a == checksum_b
assert repository.files == [str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")]
def test_repository_lightningignore_unpackage(tmp_path, monkeypatch):