Separate the concept of a Drive from that of a Mount (#15120)
* added mount class and configured it into compute config
* added mount to the cloud runtime dispatcher
* raise error if s3 bucket is passed to a drive telling the user to utilize mounts
* added example for app
* updated tests
* updated tests
* addressed code review comments
* fix bug
* bugfix
* updates
* code review comments
* updates
* fixed tests after rename
* fix tests
parent 576757fd79
commit 7a827b09bb
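In short, this PR moves S3-backed storage out of `Drive` and into a new `Mount` attached through `CloudCompute`. A minimal before/after sketch, distilled from the changes below (bucket and paths are illustrative):

    from lightning_app import CloudCompute
    from lightning_app.storage import Drive, Mount

    # Before: an S3 bucket could back a Drive. That now fails fast:
    #     Drive("s3://foo/")  # raises ValueError pointing users at Mount
    # After: S3 data is attached to a work through its compute config:
    compute = CloudCompute(mounts=Mount(source="s3://foo/", root_dir="/content/foo/"))
    # The bucket contents then appear under /content/foo/ inside the running work.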
@@ -0,0 +1 @@
+name: mount_test
@@ -0,0 +1,35 @@
+import os
+
+import lightning as L
+from lightning_app import CloudCompute
+from lightning_app.storage import Mount
+
+
+class Work(L.LightningWork):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def run(self):
+        files = os.listdir("/content/esRedditJson/")
+        for file in files:
+            print(file)
+        assert "esRedditJson1" in files
+
+
+class Flow(L.LightningFlow):
+    def __init__(self):
+        super().__init__()
+        self.work_1 = Work(
+            cloud_compute=CloudCompute(
+                mounts=Mount(
+                    source="s3://ryft-public-sample-data/esRedditJson/",
+                    root_dir="/content/esRedditJson/",
+                ),
+            )
+        )
+
+    def run(self):
+        self.work_1.run()
+
+
+app = L.LightningApp(Flow(), debug=True)
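To launch an example like this on the cloud, the `lightning` CLI of this release would be used, e.g. `lightning run app <app_file> --cloud` (command shape assumed, not part of this diff).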
@@ -52,7 +52,7 @@ from lightning_app.core.constants import (
 from lightning_app.runners.backends.cloud import CloudBackend
 from lightning_app.runners.runtime import Runtime
 from lightning_app.source_code import LocalSourceCodeDir
-from lightning_app.storage import Drive
+from lightning_app.storage import Drive, Mount
 from lightning_app.utilities.app_helpers import Logger
 from lightning_app.utilities.cloud import _get_project
 from lightning_app.utilities.dependency_caching import get_hash
@@ -148,9 +148,6 @@ class CloudRuntime(Runtime):
                     if drive.protocol == "lit://":
                         drive_type = V1DriveType.NO_MOUNT_S3
                         source_type = V1SourceType.S3
-                    elif drive.protocol == "s3://":
-                        drive_type = V1DriveType.INDEXED_S3
-                        source_type = V1SourceType.S3
                     else:
                         raise RuntimeError(
                             f"unknown drive protocol `{drive.protocol}`. Please verify this "
@@ -174,6 +171,19 @@ class CloudRuntime(Runtime):
                         ),
                     )

+            # TODO: Move this to the CloudCompute class and update backend
+            if work.cloud_compute.mounts is not None:
+                mounts = work.cloud_compute.mounts
+                if isinstance(mounts, Mount):
+                    mounts = [mounts]
+                for mount in mounts:
+                    drive_specs.append(
+                        _create_mount_drive_spec(
+                            work_name=work.name,
+                            mount=mount,
+                        )
+                    )
+
             random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
             spec = V1LightningworkSpec(
                 build_spec=build_spec,
@@ -398,3 +408,29 @@ class CloudRuntime(Runtime):
             balance = 0  # value is missing in some tests

         return balance >= 1
+
+
+def _create_mount_drive_spec(work_name: str, mount: Mount) -> V1LightningworkDrives:
+    if mount.protocol == "s3://":
+        drive_type = V1DriveType.INDEXED_S3
+        source_type = V1SourceType.S3
+    else:
+        raise RuntimeError(
+            f"unknown mount protocol `{mount.protocol}`. Please verify this "
+            f"drive type has been configured for use in the cloud dispatcher."
+        )
+
+    return V1LightningworkDrives(
+        drive=V1Drive(
+            metadata=V1Metadata(
+                name=work_name,
+            ),
+            spec=V1DriveSpec(
+                drive_type=drive_type,
+                source_type=source_type,
+                source=mount.source,
+            ),
+            status=V1DriveStatus(),
+        ),
+        mount_location=str(mount.root_dir),
+    )
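For reference, the mapping this helper performs, expressed against the expectations in the updated dispatch test further below (the `test-work` name and paths come from that test):

    from lightning_app.runners.cloud import _create_mount_drive_spec
    from lightning_app.storage import Mount

    spec = _create_mount_drive_spec(
        work_name="test-work",
        mount=Mount(source="s3://foo/", root_dir="/content/"),
    )
    # An s3:// mount becomes an INDEXED_S3 drive, and its root_dir becomes the mount location.
    assert spec.drive.spec.source == "s3://foo/"
    assert spec.mount_location == "/content/"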
@@ -1,3 +1,4 @@
 from lightning_app.storage.drive import Drive  # noqa: F401
+from lightning_app.storage.mount import Mount  # noqa: F401
 from lightning_app.storage.path import Path  # noqa: F401
 from lightning_app.storage.payload import Payload  # noqa: F401
@@ -13,7 +13,7 @@ from lightning_app.utilities.component import _is_flow_context
 class Drive:

     __IDENTIFIER__ = "__drive__"
-    __PROTOCOLS__ = ["lit://", "s3://"]
+    __PROTOCOLS__ = ["lit://"]

     def __init__(
         self,
@@ -34,6 +34,13 @@ class Drive:
                 When not provided, it is automatically inferred by Lightning.
             root_folder: This is the folder from where the Drive perceives the data (e.g. this acts as a mount dir).
         """
+        if id.startswith("s3://"):
+            raise ValueError(
+                "Using S3 buckets in a Drive is no longer supported. Please pass an S3 `Mount` to "
+                "a Work's CloudCompute config in order to mount an s3 bucket as a filesystem in a work.\n"
+                f"`CloudCompute(mount=Mount({id}), ...)`"
+            )
+
         self.id = None
         self.protocol = None
         for protocol in self.__PROTOCOLS__:
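The user-facing effect of this new check, as a quick sketch (mirroring the updated drive tests further below):

    from lightning_app import CloudCompute
    from lightning_app.storage import Drive, Mount

    Drive("s3://foo/")   # ValueError: Using S3 buckets in a Drive is no longer supported. ...
    CloudCompute(mounts=Mount(source="s3://foo/"))  # the supported replacement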
@@ -47,16 +54,10 @@ class Drive:
                 f"must start with one of the following prefixes {self.__PROTOCOLS__}"
             )

-        if self.protocol == "s3://" and not self.id.endswith("/"):
-            raise ValueError(
-                "S3 drives must end in a trailing slash (`/`) to indicate a folder is being mounted. "
-                f"Recieved: '{id}'. Mounting a single file is not currently supported."
-            )
-
         if not self.id:
             raise Exception(f"The Drive id needs to start with one of the following protocols: {self.__PROTOCOLS__}")

-        if self.protocol != "s3://" and "/" in self.id:
+        if "/" in self.id:
             raise Exception(f"The id should be unique to identify your drive. Found `{self.id}`.")

         self.root_folder = pathlib.Path(root_folder).resolve() if root_folder else pathlib.Path(os.getcwd())
@@ -88,10 +89,6 @@ class Drive:
             raise Exception("The component name needs to be known to put a path to the Drive.")
         if _is_flow_context():
             raise Exception("The flow isn't allowed to put files into a Drive.")
-        if self.protocol == "s3://":
-            raise PermissionError(
-                "S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
-            )

         self._validate_path(path)
@@ -115,10 +112,6 @@ class Drive:
         """
         if _is_flow_context():
             raise Exception("The flow isn't allowed to list files from a Drive.")
-        if self.protocol == "s3://":
-            raise PermissionError(
-                "S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?"
-            )

         if component_name:
             paths = [
@@ -163,10 +156,6 @@ class Drive:
         """
         if _is_flow_context():
             raise Exception("The flow isn't allowed to get files from a Drive.")
-        if self.protocol == "s3://":
-            raise PermissionError(
-                "S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
-            )

         if component_name:
             shared_path = self._to_shared_path(
@@ -214,10 +203,6 @@ class Drive:
         """
         if not self.component_name:
             raise Exception("The component name needs to be known to delete a path to the Drive.")
-        if self.protocol == "s3://":
-            raise PermissionError(
-                "S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?"
-            )

         shared_path = self._to_shared_path(
             path,
@@ -0,0 +1,54 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List
+
+__MOUNT_IDENTIFIER__: str = "__mount__"
+__MOUNT_PROTOCOLS__: List[str] = ["s3://"]
+
+
+@dataclass
+class Mount:
+    """Allows you to mount the contents of an AWS S3 bucket on disk when running an app on the cloud.
+
+    Arguments:
+        source: The location which contains the external data which should be mounted in the
+            running work. At the moment, only AWS S3 mounts are supported. This must be a full
+            `s3` style identifier pointing to a bucket and (optionally) prefix to mount. For
+            example: `s3://foo/bar/`.
+
+        root_dir: An absolute directory path in the work where external data source should
+            be mounted as a filesystem. This path should not already exist in your codebase.
+            If not included, then the root_dir will be set to `/data/<last folder name in the bucket>`
+    """
+
+    source: str = ""
+    root_dir: str = ""
+
+    def __post_init__(self) -> None:
+
+        for protocol in __MOUNT_PROTOCOLS__:
+            if self.source.startswith(protocol):
+                protocol = protocol
+                break
+        else:  # N.B. for-else loop
+            raise ValueError(
+                f"Unknown protocol for the mount 'source' argument '{self.source}`. The 'source' "
+                f"string must start with one of the following prefixes: {__MOUNT_PROTOCOLS__}"
+            )
+
+        if protocol == "s3://" and not self.source.endswith("/"):
+            raise ValueError(
+                "S3 mounts must end in a trailing slash (`/`) to indicate a folder is being mounted. "
+                f"Received: '{self.source}'. Mounting a single file is not currently supported."
+            )
+
+        if self.root_dir == "":
+            self.root_dir = f"/data/{Path(self.source).stem}"
+
+    @property
+    def protocol(self) -> str:
+        """The backing storage protocol indicated by this drive source."""
+        for protocol in __MOUNT_PROTOCOLS__:
+            if self.source.startswith(protocol):
+                return protocol
+        return ""
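The defaulting and validation behavior of `Mount`, sketched against the unit tests added later in this commit:

    from lightning_app.storage import Mount

    # root_dir defaults to /data/<last folder name in the source>
    assert Mount(source="s3://foo/").root_dir == "/data/foo"
    assert Mount(source="s3://foo/bar/").root_dir == "/data/bar"
    # Non-s3 sources and sources without a trailing slash raise ValueError.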
@@ -6,10 +6,10 @@ from types import FrameType
 from typing import cast, List, Optional, TYPE_CHECKING, Union

 from lightning_app.utilities.app_helpers import Logger
+from lightning_app.utilities.packaging.cloud_compute import CloudCompute

 if TYPE_CHECKING:
     from lightning_app import LightningWork
-    from lightning_app.utilities.packaging.cloud_compute import CloudCompute

 logger = Logger(__name__)
@@ -1,8 +1,9 @@
 from dataclasses import asdict, dataclass
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 from uuid import uuid4

 from lightning_app.core.constants import ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER
+from lightning_app.storage.mount import Mount

 __CLOUD_COMPUTE_IDENTIFIER__ = "__cloud_compute__"
@@ -47,7 +48,8 @@ _CLOUD_COMPUTE_STORE = {}

 @dataclass
 class CloudCompute:
-    """
+    """Configure the cloud runtime for a lightning work or flow.
+
     Arguments:
         name: The name of the hardware to use. A full list of supported options can be found in
             :doc:`/core_api/lightning_work/compute`. If you have a request for more hardware options, please contact
@@ -77,6 +79,8 @@ class CloudCompute:

         shm_size: Shared memory size in MiB, backed by RAM. min 512, max 8192, it will auto update in steps of 512.
             For example 1100 will become 1024. If set to zero (the default) will get the default 64MiB inside docker.
+
+        mounts: External data sources which should be mounted into a work as a filesystem at runtime.
     """

     name: str = "default"
@@ -86,9 +90,12 @@ class CloudCompute:
     wait_timeout: Optional[int] = None
     idle_timeout: Optional[int] = None
     shm_size: Optional[int] = 0
+    mounts: Optional[Union[Mount, List[Mount]]] = None
     _internal_id: Optional[str] = None

-    def __post_init__(self):
+    def __post_init__(self) -> None:
+        _verify_mount_root_dirs_are_unique(self.mounts)
+
         if self.clusters:
             raise ValueError("Clusters aren't supported yet. Coming soon.")
         if self.wait_timeout:
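Since `__post_init__` now runs `_verify_mount_root_dirs_are_unique`, attaching two mounts that share a `root_dir` fails at construction time; a sketch matching the new test below:

    from lightning_app import CloudCompute
    from lightning_app.storage import Mount

    CloudCompute(
        "gpu",
        mounts=[
            Mount(source="s3://foo/", root_dir="./foo"),
            Mount(source="s3://foo/bar/", root_dir="./foo"),  # duplicate root_dir
        ],
    )  # raises ValueError: Every Mount attached to a work must have a unique 'root_dir' argument.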
@@ -100,12 +107,28 @@ class CloudCompute:
         if self._internal_id is None:
             self._internal_id = "default" if self.name == "default" else uuid4().hex[:7]

-    def to_dict(self):
+    def to_dict(self) -> dict:
+        _verify_mount_root_dirs_are_unique(self.mounts)
         return {"type": __CLOUD_COMPUTE_IDENTIFIER__, **asdict(self)}

     @classmethod
-    def from_dict(cls, d):
+    def from_dict(cls, d: dict) -> "CloudCompute":
         assert d.pop("type") == __CLOUD_COMPUTE_IDENTIFIER__
+        mounts = d.pop("mounts", None)
+        if mounts is None:
+            pass
+        elif isinstance(mounts, dict):
+            d["mounts"] = Mount(**mounts)
+        elif isinstance(mounts, (list)):
+            d["mounts"] = []
+            for mount in mounts:
+                d["mounts"].append(Mount(**mount))
+        else:
+            raise TypeError(
+                f"mounts argument must be one of [None, Mount, List[Mount]], "
+                f"received {mounts} of type {type(mounts)}"
+            )
+        _verify_mount_root_dirs_are_unique(d.get("mounts", None))
         return cls(**d)

     @property
@@ -116,6 +139,13 @@ class CloudCompute:
         return self.name == "default"


+def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], Tuple[Mount]]) -> None:
+    if isinstance(mounts, (list, tuple, set)):
+        root_dirs = [mount.root_dir for mount in mounts]
+        if len(set(root_dirs)) != len(root_dirs):
+            raise ValueError("Every Mount attached to a work must have a unique 'root_dir' argument.")
+
+
 def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]:
     if state and __CLOUD_COMPUTE_IDENTIFIER__ == state.get("type", None):
         cloud_compute = CloudCompute.from_dict(state)
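Because `asdict` flattens `Mount` dataclasses into plain dicts, `from_dict` has to rebuild them; a round-trip sketch based on the new CloudCompute test at the end of this diff:

    from lightning_app import CloudCompute
    from lightning_app.storage import Mount

    cc = CloudCompute("gpu", mounts=[Mount(source="s3://foo/", root_dir="./foo")])
    cc_dict = cc.to_dict()
    assert cc_dict["mounts"] == [{"root_dir": "./foo", "source": "s3://foo/"}]
    assert CloudCompute.from_dict(cc_dict) == cc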
@@ -333,6 +333,7 @@ def test_lightning_flow_and_work():
                 "preemptible": False,
                 "wait_timeout": None,
                 "idle_timeout": None,
+                "mounts": None,
                 "shm_size": 0,
                 "_internal_id": "default",
             },
@@ -358,6 +359,7 @@ def test_lightning_flow_and_work():
                 "preemptible": False,
                 "wait_timeout": None,
                 "idle_timeout": None,
+                "mounts": None,
                 "shm_size": 0,
                 "_internal_id": "default",
             },
@@ -399,6 +401,7 @@ def test_lightning_flow_and_work():
                 "preemptible": False,
                 "wait_timeout": None,
                 "idle_timeout": None,
+                "mounts": None,
                 "shm_size": 0,
                 "_internal_id": "default",
             },
@@ -424,6 +427,7 @@ def test_lightning_flow_and_work():
                 "preemptible": False,
                 "wait_timeout": None,
                 "idle_timeout": None,
+                "mounts": None,
                 "shm_size": 0,
                 "_internal_id": "default",
             },
@@ -32,7 +32,7 @@ from lightning_cloud.openapi import (

 from lightning_app import LightningApp, LightningWork
 from lightning_app.runners import backends, cloud
-from lightning_app.storage import Drive
+from lightning_app.storage import Drive, Mount
 from lightning_app.utilities.cloud import _get_project
 from lightning_app.utilities.dependency_caching import get_hash

@@ -54,8 +54,8 @@ class WorkWithSingleDrive(LightningWork):
 class WorkWithTwoDrives(LightningWork):
     def __init__(self):
         super().__init__()
-        self.lit_drive = None
-        self.s3_drive = None
+        self.lit_drive_1 = None
+        self.lit_drive_2 = None

     def run(self):
         pass
@@ -474,21 +474,10 @@ class TestAppCreationClient:
         # should be the results of the deepcopy operation (an instance of the original class)
         mocked_lit_drive.__deepcopy__.return_value = copy(mocked_lit_drive)

-        mocked_s3_drive = MagicMock(spec=Drive)
-        setattr(mocked_s3_drive, "id", "some-bucket/path/")
-        setattr(mocked_s3_drive, "protocol", "s3://")
-        setattr(mocked_s3_drive, "component_name", "test-work")
-        setattr(mocked_s3_drive, "allow_duplicates", False)
-        setattr(mocked_s3_drive, "root_folder", "/hello/")
-        # deepcopy on a MagicMock instance will return an empty magicmock instance. To
-        # overcome this we set the __deepcopy__ method `return_value` to equal what
-        # should be the results of the deepcopy operation (an instance of the original class)
-        mocked_s3_drive.__deepcopy__.return_value = copy(mocked_s3_drive)
-
         work = WorkWithTwoDrives()
-        monkeypatch.setattr(work, "lit_drive", mocked_lit_drive)
-        monkeypatch.setattr(work, "s3_drive", mocked_s3_drive)
-        monkeypatch.setattr(work, "_state", {"_port", "_name", "lit_drive", "s3_drive"})
+        monkeypatch.setattr(work, "lit_drive_1", mocked_lit_drive)
+        monkeypatch.setattr(work, "lit_drive_2", mocked_lit_drive)
+        monkeypatch.setattr(work, "_state", {"_port", "_name", "lit_drive_1", "lit_drive_2"})
         monkeypatch.setattr(work, "_name", "test-work")
         monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
         monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
@@ -507,24 +496,24 @@ class TestAppCreationClient:
         cloud_runtime.dispatch()

         if lightningapps:
-            s3_drive_spec = V1LightningworkDrives(
+            lit_drive_1_spec = V1LightningworkDrives(
                 drive=V1Drive(
                     metadata=V1Metadata(
-                        name="test-work.s3_drive",
+                        name="test-work.lit_drive_1",
                     ),
                     spec=V1DriveSpec(
-                        drive_type=V1DriveType.INDEXED_S3,
+                        drive_type=V1DriveType.NO_MOUNT_S3,
                         source_type=V1SourceType.S3,
-                        source="s3://some-bucket/path/",
+                        source="lit://foobar",
                     ),
                     status=V1DriveStatus(),
                 ),
-                mount_location="/hello/",
+                mount_location=str(tmpdir),
             )
-            lit_drive_spec = V1LightningworkDrives(
+            lit_drive_2_spec = V1LightningworkDrives(
                 drive=V1Drive(
                     metadata=V1Metadata(
-                        name="test-work.lit_drive",
+                        name="test-work.lit_drive_2",
                     ),
                     spec=V1DriveSpec(
                         drive_type=V1DriveType.NO_MOUNT_S3,
@@ -562,7 +551,7 @@ class TestAppCreationClient:
                             ),
                             image="random_base_public_image",
                         ),
-                        drives=[lit_drive_spec, s3_drive_spec],
+                        drives=[lit_drive_2_spec, lit_drive_1_spec],
                         user_requested_compute_config=V1UserRequestedComputeConfig(
                             name="default", count=1, disk_size=0, preemptible=False, shm_size=0
                         ),
@@ -595,7 +584,7 @@ class TestAppCreationClient:
                             ),
                             image="random_base_public_image",
                         ),
-                        drives=[s3_drive_spec, lit_drive_spec],
+                        drives=[lit_drive_1_spec, lit_drive_2_spec],
                         user_requested_compute_config=V1UserRequestedComputeConfig(
                             name="default", count=1, disk_size=0, preemptible=False, shm_size=0
                         ),
@@ -632,6 +621,153 @@ class TestAppCreationClient:
             project_id="test-project-id", app_id=mock.ANY, id=mock.ANY, body=mock.ANY
         )

+    @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock())
+    @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
+    def test_call_with_work_app_and_attached_mount_and_drive(self, lightningapps, monkeypatch, tmpdir):
+        source_code_root_dir = Path(tmpdir / "src").absolute()
+        source_code_root_dir.mkdir()
+        Path(source_code_root_dir / ".lightning").write_text("name: myapp")
+        requirements_file = Path(source_code_root_dir / "requirements.txt")
+        Path(requirements_file).touch()
+
+        mock_client = mock.MagicMock()
+        if lightningapps:
+            lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED
+        mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
+            V1ListLightningappInstancesResponse(lightningapps=lightningapps)
+        )
+        lightning_app_instance = MagicMock()
+        mock_client.lightningapp_v2_service_create_lightningapp_release = MagicMock(return_value=lightning_app_instance)
+        mock_client.lightningapp_v2_service_create_lightningapp_release_instance = MagicMock(
+            return_value=lightning_app_instance
+        )
+        existing_instance = MagicMock()
+        existing_instance.status.phase = V1LightningappInstanceState.STOPPED
+        mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance)
+        cloud_backend = mock.MagicMock()
+        cloud_backend.client = mock_client
+        monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
+        monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
+        monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
+        app = mock.MagicMock()
+        flow = mock.MagicMock()
+
+        mocked_drive = MagicMock(spec=Drive)
+        setattr(mocked_drive, "id", "foobar")
+        setattr(mocked_drive, "protocol", "lit://")
+        setattr(mocked_drive, "component_name", "test-work")
+        setattr(mocked_drive, "allow_duplicates", False)
+        setattr(mocked_drive, "root_folder", tmpdir)
+        # deepcopy on a MagicMock instance will return an empty magicmock instance. To
+        # overcome this we set the __deepcopy__ method `return_value` to equal what
+        # should be the results of the deepcopy operation (an instance of the original class)
+        mocked_drive.__deepcopy__.return_value = copy(mocked_drive)
+
+        mocked_mount = MagicMock(spec=Mount)
+        setattr(mocked_mount, "source", "s3://foo/")
+        setattr(mocked_mount, "root_dir", "/content/")
+        setattr(mocked_mount, "protocol", "s3://")
+
+        work = WorkWithSingleDrive()
+        monkeypatch.setattr(work, "drive", mocked_drive)
+        monkeypatch.setattr(work, "_state", {"_port", "drive"})
+        monkeypatch.setattr(work, "_name", "test-work")
+        monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
+        monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
+        monkeypatch.setattr(work._cloud_build_config, "image", "random_base_public_image")
+        monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
+        monkeypatch.setattr(work._cloud_compute, "preemptible", False)
+        monkeypatch.setattr(work._cloud_compute, "mounts", mocked_mount)
+        monkeypatch.setattr(work, "_port", 8080)
+
+        flow.works = lambda recurse: [work]
+        app.flows = [flow]
+        cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py"))
+        monkeypatch.setattr(
+            "lightning_app.runners.cloud._get_project",
+            lambda x: V1Membership(name="test-project", project_id="test-project-id"),
+        )
+        cloud_runtime.dispatch()
+
+        if lightningapps:
+            expected_body = Body8(
+                description=None,
+                local_source=True,
+                app_entrypoint_file="entrypoint.py",
+                enable_app_server=True,
+                flow_servers=[],
+                dependency_cache_key=get_hash(requirements_file),
+                image_spec=Gridv1ImageSpec(
+                    dependency_file_info=V1DependencyFileInfo(
+                        package_manager=V1PackageManager.PIP, path="requirements.txt"
+                    )
+                ),
+                works=[
+                    V1Work(
+                        name="test-work",
+                        spec=V1LightningworkSpec(
+                            build_spec=V1BuildSpec(
+                                commands=["echo 'start'"],
+                                python_dependencies=V1PythonDependencyInfo(
+                                    package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
+                                ),
+                                image="random_base_public_image",
+                            ),
+                            drives=[
+                                V1LightningworkDrives(
+                                    drive=V1Drive(
+                                        metadata=V1Metadata(
+                                            name="test-work.drive",
+                                        ),
+                                        spec=V1DriveSpec(
+                                            drive_type=V1DriveType.NO_MOUNT_S3,
+                                            source_type=V1SourceType.S3,
+                                            source="lit://foobar",
+                                        ),
+                                        status=V1DriveStatus(),
+                                    ),
+                                    mount_location=str(tmpdir),
+                                ),
+                                V1LightningworkDrives(
+                                    drive=V1Drive(
+                                        metadata=V1Metadata(
+                                            name="test-work",
+                                        ),
+                                        spec=V1DriveSpec(
+                                            drive_type=V1DriveType.INDEXED_S3,
+                                            source_type=V1SourceType.S3,
+                                            source="s3://foo/",
+                                        ),
+                                        status=V1DriveStatus(),
+                                    ),
+                                    mount_location="/content/",
+                                ),
+                            ],
+                            user_requested_compute_config=V1UserRequestedComputeConfig(
+                                name="default", count=1, disk_size=0, preemptible=False, shm_size=0
+                            ),
+                            network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
+                        ),
+                    )
+                ],
+            )
+            mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
+                project_id="test-project-id", app_id=mock.ANY, body=expected_body
+            )
+
+            # running dispatch with disabled dependency cache
+            mock_client.reset_mock()
+            monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True)
+            expected_body.dependency_cache_key = None
+            cloud_runtime.dispatch()
+            mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
+                project_id="test-project-id", app_id=mock.ANY, body=expected_body
+            )
+        else:
+            mock_client.lightningapp_v2_service_create_lightningapp_release_instance.assert_called_once_with(
+                project_id="test-project-id", app_id=mock.ANY, id=mock.ANY, body=mock.ANY
+            )
+

     @mock.patch("lightning_app.core.queues.QueuingSystem", MagicMock())
     @mock.patch("lightning_app.runners.backends.cloud.LightningClient", MagicMock())
@@ -213,44 +213,7 @@ def test_lit_drive():
     os.remove("a.txt")


-def test_s3_drives():
-    drive = Drive("s3://foo/", allow_duplicates=True)
-    drive.component_name = "root.work"
-
-    with pytest.raises(
-        Exception, match="S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
-    ):
-        drive.put("a.txt")
-    with pytest.raises(
-        Exception,
-        match="S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?",
-    ):
-        drive.list("a.txt")
-    with pytest.raises(
-        Exception, match="S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
-    ):
-        drive.get("a.txt")
-    with pytest.raises(
-        Exception,
-        match="S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?",
-    ):
-        drive.delete("a.txt")
-
-    _set_flow_context()
-    with pytest.raises(Exception, match="The flow isn't allowed to put files into a Drive."):
-        drive.put("a.txt")
-    with pytest.raises(Exception, match="The flow isn't allowed to list files from a Drive."):
-        drive.list("a.txt")
-    with pytest.raises(Exception, match="The flow isn't allowed to get files from a Drive."):
-        drive.get("a.txt")
-
-
-def test_create_s3_drive_without_trailing_slash_fails():
-    with pytest.raises(ValueError, match="S3 drives must end in a trailing slash"):
-        Drive("s3://foo")
-
-
-@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"])
+@pytest.mark.parametrize("drive_id", ["lit://drive"])
 def test_maybe_create_drive(drive_id):
     drive = Drive(drive_id, allow_duplicates=False)
     drive.component_name = "root.work1"
@@ -260,7 +223,7 @@ def test_maybe_create_drive(drive_id):
     assert new_drive.component_name == drive.component_name


-@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"])
+@pytest.mark.parametrize("drive_id", ["lit://drive"])
 def test_drive_deepcopy(drive_id):
     drive = Drive(drive_id, allow_duplicates=True)
     drive.component_name = "root.work1"
@@ -269,8 +232,9 @@ def test_drive_deepcopy(drive_id):
     assert new_drive.component_name == drive.component_name


-def test_drive_root_folder_pass():
-    Drive("s3://drive/", root_folder="a")
+def test_s3_drive_raises_error_telling_users_to_use_mounts():
+    with pytest.raises(ValueError, match="Using S3 buckets in a Drive is no longer supported."):
+        Drive("s3://foo/")


 def test_drive_root_folder_breaks():
@@ -0,0 +1,37 @@
+import pytest
+
+from lightning_app.storage.mount import Mount
+
+
+def test_create_s3_mount_successfully():
+    mount = Mount(source="s3://foo/bar/", root_dir="./foo")
+    assert mount.source == "s3://foo/bar/"
+    assert mount.root_dir == "./foo"
+    assert mount.protocol == "s3://"
+
+
+def test_create_non_s3_mount_fails():
+    with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
+        Mount(source="foo/bar/", root_dir="./foo")
+
+    with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
+        Mount(source="gcs://foo/bar/", root_dir="./foo")
+
+    with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
+        Mount(source="3://foo/bar/", root_dir="./foo")
+
+
+def test_create_s3_mount_without_directory_prefix_fails():
+    with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
+        Mount(source="s3://foo/bar", root_dir="./foo")
+
+    with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
+        Mount(source="s3://foo", root_dir="./foo")
+
+
+def test_create_mount_without_root_dir_argument():
+    m = Mount(source="s3://foo/")
+    assert m.root_dir == "/data/foo"
+
+    m = Mount(source="s3://foo/bar/")
+    assert m.root_dir == "/data/bar"
@@ -53,6 +53,7 @@ def test_dict():
         "preemptible": False,
         "wait_timeout": None,
         "idle_timeout": None,
+        "mounts": None,
         "shm_size": 0,
         "_internal_id": "default",
     },
@@ -87,6 +88,7 @@ def test_dict():
         "preemptible": False,
         "wait_timeout": None,
         "idle_timeout": None,
+        "mounts": None,
         "shm_size": 0,
         "_internal_id": "default",
     },
@@ -121,6 +123,7 @@ def test_dict():
         "preemptible": False,
         "wait_timeout": None,
         "idle_timeout": None,
+        "mounts": None,
         "shm_size": 0,
         "_internal_id": "default",
     },
@@ -207,6 +210,7 @@ def test_list():
         "preemptible": False,
         "wait_timeout": None,
         "idle_timeout": None,
+        "mounts": None,
         "shm_size": 0,
         "_internal_id": "default",
     },
@@ -241,6 +245,7 @@ def test_list():
         "preemptible": False,
         "wait_timeout": None,
         "idle_timeout": None,
+        "mounts": None,
         "shm_size": 0,
         "_internal_id": "default",
     },
@@ -270,6 +275,7 @@ def test_list():
         "preemptible": False,
         "wait_timeout": None,
         "idle_timeout": None,
+        "mounts": None,
         "shm_size": 0,
         "_internal_id": "default",
     },
@@ -1,6 +1,7 @@
 import pytest

 from lightning_app import CloudCompute
+from lightning_app.storage import Mount


 def test_cloud_compute_unsupported_features():
@@ -17,6 +18,33 @@ def test_cloud_compute_names():


 def test_cloud_compute_shared_memory():

     cloud_compute = CloudCompute("gpu", shm_size=1100)
     assert cloud_compute.shm_size == 1100
+
+
+def test_cloud_compute_with_mounts():
+    mount_1 = Mount(source="s3://foo/", root_dir="./foo")
+    mount_2 = Mount(source="s3://foo/bar/", root_dir="./bar")
+
+    cloud_compute = CloudCompute("gpu", mounts=mount_1)
+    assert cloud_compute.mounts == mount_1
+
+    cloud_compute = CloudCompute("gpu", mounts=[mount_1, mount_2])
+    assert cloud_compute.mounts == [mount_1, mount_2]
+
+    cc_dict = cloud_compute.to_dict()
+    assert "mounts" in cc_dict
+    assert cc_dict["mounts"] == [
+        {"root_dir": "./foo", "source": "s3://foo/"},
+        {"root_dir": "./bar", "source": "s3://foo/bar/"},
+    ]
+
+    assert CloudCompute.from_dict(cc_dict) == cloud_compute
+
+
+def test_cloud_compute_with_non_unique_mount_root_dirs():
+    mount_1 = Mount(source="s3://foo/", root_dir="./foo")
+    mount_2 = Mount(source="s3://foo/bar/", root_dir="./foo")
+
+    with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique 'root_dir' argument."):
+        CloudCompute("gpu", mounts=[mount_1, mount_2])