Separate the concept of a Drive from that of a Mount (#15120)

* added mount class and configured it into compute config

* added mount to the cloud runtime dispatcher

* raise an error if an S3 bucket is passed to a Drive, telling the user to use Mounts instead

* added example for app

* updated tests

* updated tests

* addressed code review comments

* fix bug

* bugfix

* updates

* code review comments

* updates

* fixed tests after rename

* fix tests
Rick Izzo 2022-10-19 23:24:27 -04:00, committed by GitHub
parent 576757fd79 · commit 7a827b09bb
14 changed files with 419 additions and 102 deletions
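At a glance, the new split looks like this (a minimal sketch using the names introduced in this commit; the drive id and bucket path are illustrative):

from lightning_app import CloudCompute
from lightning_app.storage import Drive, Mount

# Drives keep handling read/write artifact storage over lit://...
drive = Drive("lit://checkpoints")

# ...while S3 buckets are now attached as filesystem Mounts through a
# work's CloudCompute config and read under a chosen root_dir.
compute = CloudCompute(
    "gpu",
    mounts=Mount(source="s3://my-bucket/files/", root_dir="/content/files/"),
)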

View File

@@ -0,0 +1 @@
name: mount_test

examples/app_mount/app.py (new file, 35 lines)

@@ -0,0 +1,35 @@
import os

import lightning as L
from lightning_app import CloudCompute
from lightning_app.storage import Mount


class Work(L.LightningWork):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def run(self):
        # Files in the mounted S3 bucket appear as ordinary files under root_dir.
        files = os.listdir("/content/esRedditJson/")
        for file in files:
            print(file)
        assert "esRedditJson1" in files


class Flow(L.LightningFlow):
    def __init__(self):
        super().__init__()
        self.work_1 = Work(
            cloud_compute=CloudCompute(
                mounts=Mount(
                    source="s3://ryft-public-sample-data/esRedditJson/",
                    root_dir="/content/esRedditJson/",
                ),
            )
        )

    def run(self):
        self.work_1.run()


app = L.LightningApp(Flow(), debug=True)
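A note on usage: Mounts are materialized only by the cloud runtime, so this example needs to be dispatched to the cloud; with the Lightning CLI of this era that would look something like `lightning run app app.py --cloud` (invocation assumed, not shown in this commit).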

View File

@@ -52,7 +52,7 @@ from lightning_app.core.constants import (
from lightning_app.runners.backends.cloud import CloudBackend
from lightning_app.runners.runtime import Runtime
from lightning_app.source_code import LocalSourceCodeDir
from lightning_app.storage import Drive
from lightning_app.storage import Drive, Mount
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.dependency_caching import get_hash
@@ -148,9 +148,6 @@ class CloudRuntime(Runtime):
if drive.protocol == "lit://":
drive_type = V1DriveType.NO_MOUNT_S3
source_type = V1SourceType.S3
elif drive.protocol == "s3://":
drive_type = V1DriveType.INDEXED_S3
source_type = V1SourceType.S3
else:
raise RuntimeError(
f"unknown drive protocol `{drive.protocol}`. Please verify this "
@@ -174,6 +171,19 @@
),
)
# TODO: Move this to the CloudCompute class and update backend
if work.cloud_compute.mounts is not None:
mounts = work.cloud_compute.mounts
if isinstance(mounts, Mount):
mounts = [mounts]
for mount in mounts:
drive_specs.append(
_create_mount_drive_spec(
work_name=work.name,
mount=mount,
)
)
random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
spec = V1LightningworkSpec(
build_spec=build_spec,
@@ -398,3 +408,29 @@ class CloudRuntime(Runtime):
balance = 0 # value is missing in some tests
return balance >= 1
def _create_mount_drive_spec(work_name: str, mount: Mount) -> V1LightningworkDrives:
if mount.protocol == "s3://":
drive_type = V1DriveType.INDEXED_S3
source_type = V1SourceType.S3
else:
raise RuntimeError(
f"unknown mount protocol `{mount.protocol}`. Please verify this "
f"drive type has been configured for use in the cloud dispatcher."
)
return V1LightningworkDrives(
drive=V1Drive(
metadata=V1Metadata(
name=work_name,
),
spec=V1DriveSpec(
drive_type=drive_type,
source_type=source_type,
source=mount.source,
),
status=V1DriveStatus(),
),
mount_location=str(mount.root_dir),
)

View File

@@ -1,3 +1,4 @@
from lightning_app.storage.drive import Drive # noqa: F401
from lightning_app.storage.mount import Mount # noqa: F401
from lightning_app.storage.path import Path # noqa: F401
from lightning_app.storage.payload import Payload # noqa: F401

View File

@@ -13,7 +13,7 @@ from lightning_app.utilities.component import _is_flow_context
class Drive:
__IDENTIFIER__ = "__drive__"
__PROTOCOLS__ = ["lit://", "s3://"]
__PROTOCOLS__ = ["lit://"]
def __init__(
self,
@@ -34,6 +34,13 @@ class Drive:
When not provided, it is automatically inferred by Lightning.
            root_folder: This is the folder from which the Drive perceives the data (i.e., it acts as a mount dir).
"""
        if id.startswith("s3://"):
            raise ValueError(
                "Using S3 buckets in a Drive is no longer supported. Please pass an S3 `Mount` to "
                "a Work's CloudCompute config in order to mount an S3 bucket as a filesystem in a Work.\n"
                f"`CloudCompute(mounts=Mount(source='{id}'), ...)`"
            )
self.id = None
self.protocol = None
for protocol in self.__PROTOCOLS__:
@@ -47,16 +54,10 @@ class Drive:
f"must start with one of the following prefixes {self.__PROTOCOLS__}"
)
if self.protocol == "s3://" and not self.id.endswith("/"):
raise ValueError(
"S3 drives must end in a trailing slash (`/`) to indicate a folder is being mounted. "
f"Recieved: '{id}'. Mounting a single file is not currently supported."
)
if not self.id:
raise Exception(f"The Drive id needs to start with one of the following protocols: {self.__PROTOCOLS__}")
if self.protocol != "s3://" and "/" in self.id:
if "/" in self.id:
raise Exception(f"The id should be unique to identify your drive. Found `{self.id}`.")
self.root_folder = pathlib.Path(root_folder).resolve() if root_folder else pathlib.Path(os.getcwd())
@@ -88,10 +89,6 @@ class Drive:
raise Exception("The component name needs to be known to put a path to the Drive.")
if _is_flow_context():
raise Exception("The flow isn't allowed to put files into a Drive.")
if self.protocol == "s3://":
raise PermissionError(
"S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
)
self._validate_path(path)
@@ -115,10 +112,6 @@
"""
if _is_flow_context():
raise Exception("The flow isn't allowed to list files from a Drive.")
if self.protocol == "s3://":
raise PermissionError(
"S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?"
)
if component_name:
paths = [
@@ -163,10 +156,6 @@
"""
if _is_flow_context():
raise Exception("The flow isn't allowed to get files from a Drive.")
if self.protocol == "s3://":
raise PermissionError(
"S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
)
if component_name:
shared_path = self._to_shared_path(
@@ -214,10 +203,6 @@
"""
if not self.component_name:
raise Exception("The component name needs to be known to delete a path to the Drive.")
if self.protocol == "s3://":
raise PermissionError(
"S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?"
)
shared_path = self._to_shared_path(
path,

View File

@@ -0,0 +1,54 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List
__MOUNT_IDENTIFIER__: str = "__mount__"
__MOUNT_PROTOCOLS__: List[str] = ["s3://"]
@dataclass
class Mount:
"""Allows you to mount the contents of an AWS S3 bucket on disk when running an app on the cloud.
Arguments:
source: The location which contains the external data which should be mounted in the
running work. At the moment, only AWS S3 mounts are supported. This must be a full
`s3` style identifier pointing to a bucket and (optionally) prefix to mount. For
example: `s3://foo/bar/`.
        root_dir: An absolute directory path in the work where the external data source should
            be mounted as a filesystem. This path should not already exist in your codebase.
            If not provided, `root_dir` defaults to `/data/<last folder name in the bucket>`.
"""
source: str = ""
root_dir: str = ""
    def __post_init__(self) -> None:
        for protocol in __MOUNT_PROTOCOLS__:
            if self.source.startswith(protocol):
                break
        else:  # N.B. for-else: reached only when no protocol matched
            raise ValueError(
                f"Unknown protocol for the mount 'source' argument '{self.source}'. The 'source' "
                f"string must start with one of the following prefixes: {__MOUNT_PROTOCOLS__}"
            )
if protocol == "s3://" and not self.source.endswith("/"):
raise ValueError(
"S3 mounts must end in a trailing slash (`/`) to indicate a folder is being mounted. "
f"Received: '{self.source}'. Mounting a single file is not currently supported."
)
if self.root_dir == "":
self.root_dir = f"/data/{Path(self.source).stem}"
@property
def protocol(self) -> str:
"""The backing storage protocol indicated by this drive source."""
for protocol in __MOUNT_PROTOCOLS__:
if self.source.startswith(protocol):
return protocol
return ""

View File

@@ -6,10 +6,10 @@ from types import FrameType
from typing import cast, List, Optional, TYPE_CHECKING, Union
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
if TYPE_CHECKING:
from lightning_app import LightningWork
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
logger = Logger(__name__)

View File

@@ -1,8 +1,9 @@
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
from uuid import uuid4
from lightning_app.core.constants import ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER
from lightning_app.storage.mount import Mount
__CLOUD_COMPUTE_IDENTIFIER__ = "__cloud_compute__"
@@ -47,7 +48,8 @@ _CLOUD_COMPUTE_STORE = {}
@dataclass
class CloudCompute:
"""
"""Configure the cloud runtime for a lightning work or flow.
Arguments:
name: The name of the hardware to use. A full list of supported options can be found in
:doc:`/core_api/lightning_work/compute`. If you have a request for more hardware options, please contact
@@ -77,6 +79,8 @@ class CloudCompute:
        shm_size: Shared memory size in MiB, backed by RAM. Min 512, max 8192; the value auto-updates
            in steps of 512 (for example, 1100 becomes 1024). If set to zero (the default), Docker's
            default of 64MiB is used.
mounts: External data sources which should be mounted into a work as a filesystem at runtime.
"""
name: str = "default"
@@ -86,9 +90,12 @@ class CloudCompute:
wait_timeout: Optional[int] = None
idle_timeout: Optional[int] = None
shm_size: Optional[int] = 0
mounts: Optional[Union[Mount, List[Mount]]] = None
_internal_id: Optional[str] = None
def __post_init__(self):
def __post_init__(self) -> None:
_verify_mount_root_dirs_are_unique(self.mounts)
if self.clusters:
raise ValueError("Clusters are't supported yet. Coming soon.")
if self.wait_timeout:
@@ -100,12 +107,28 @@
if self._internal_id is None:
self._internal_id = "default" if self.name == "default" else uuid4().hex[:7]
def to_dict(self):
def to_dict(self) -> dict:
_verify_mount_root_dirs_are_unique(self.mounts)
return {"type": __CLOUD_COMPUTE_IDENTIFIER__, **asdict(self)}
@classmethod
def from_dict(cls, d):
def from_dict(cls, d: dict) -> "CloudCompute":
assert d.pop("type") == __CLOUD_COMPUTE_IDENTIFIER__
        mounts = d.pop("mounts", None)
        if mounts is None:
            pass
        elif isinstance(mounts, dict):
            d["mounts"] = Mount(**mounts)
        elif isinstance(mounts, list):
            d["mounts"] = [Mount(**mount) for mount in mounts]
        else:
            raise TypeError(
                f"mounts argument must be one of [None, Mount, List[Mount]], "
                f"received {mounts} of type {type(mounts)}"
            )
        _verify_mount_root_dirs_are_unique(d.get("mounts", None))
        return cls(**d)
return cls(**d)
@property
@@ -116,6 +139,13 @@ class CloudCompute:
return self.name == "default"
def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], Tuple[Mount, ...]]) -> None:
if isinstance(mounts, (list, tuple, set)):
root_dirs = [mount.root_dir for mount in mounts]
if len(set(root_dirs)) != len(root_dirs):
raise ValueError("Every Mount attached to a work must have a unique 'root_dir' argument.")
def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]:
if state and __CLOUD_COMPUTE_IDENTIFIER__ == state.get("type", None):
cloud_compute = CloudCompute.from_dict(state)
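Taken together, a Mount attached to a CloudCompute survives the dict round-trip used to ship work state (a small sketch mirroring the new CloudCompute tests at the end of this diff):

from lightning_app import CloudCompute
from lightning_app.storage import Mount

cc = CloudCompute("gpu", mounts=Mount(source="s3://foo/"))
state = cc.to_dict()  # the Mount is serialized via dataclasses.asdict
assert CloudCompute.from_dict(state) == cc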

View File

@@ -333,6 +333,7 @@ def test_lightning_flow_and_work():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},
@@ -358,6 +359,7 @@ def test_lightning_flow_and_work():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},
@@ -399,6 +401,7 @@ def test_lightning_flow_and_work():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},
@@ -424,6 +427,7 @@ def test_lightning_flow_and_work():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},

View File

@@ -32,7 +32,7 @@ from lightning_cloud.openapi import (
from lightning_app import LightningApp, LightningWork
from lightning_app.runners import backends, cloud
from lightning_app.storage import Drive
from lightning_app.storage import Drive, Mount
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.dependency_caching import get_hash
@@ -54,8 +54,8 @@ class WorkWithSingleDrive(LightningWork):
class WorkWithTwoDrives(LightningWork):
def __init__(self):
super().__init__()
self.lit_drive = None
self.s3_drive = None
self.lit_drive_1 = None
self.lit_drive_2 = None
def run(self):
pass
@@ -474,21 +474,10 @@ class TestAppCreationClient:
# should be the results of the deepcopy operation (an instance of the original class)
mocked_lit_drive.__deepcopy__.return_value = copy(mocked_lit_drive)
mocked_s3_drive = MagicMock(spec=Drive)
setattr(mocked_s3_drive, "id", "some-bucket/path/")
setattr(mocked_s3_drive, "protocol", "s3://")
setattr(mocked_s3_drive, "component_name", "test-work")
setattr(mocked_s3_drive, "allow_duplicates", False)
setattr(mocked_s3_drive, "root_folder", "/hello/")
# deepcopy on a MagicMock instance will return an empty magicmock instance. To
# overcome this we set the __deepcopy__ method `return_value` to equal what
# should be the results of the deepcopy operation (an instance of the original class)
mocked_s3_drive.__deepcopy__.return_value = copy(mocked_s3_drive)
work = WorkWithTwoDrives()
monkeypatch.setattr(work, "lit_drive", mocked_lit_drive)
monkeypatch.setattr(work, "s3_drive", mocked_s3_drive)
monkeypatch.setattr(work, "_state", {"_port", "_name", "lit_drive", "s3_drive"})
monkeypatch.setattr(work, "lit_drive_1", mocked_lit_drive)
monkeypatch.setattr(work, "lit_drive_2", mocked_lit_drive)
monkeypatch.setattr(work, "_state", {"_port", "_name", "lit_drive_1", "lit_drive_2"})
monkeypatch.setattr(work, "_name", "test-work")
monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
@@ -507,24 +496,24 @@ class TestAppCreationClient:
cloud_runtime.dispatch()
if lightningapps:
s3_drive_spec = V1LightningworkDrives(
lit_drive_1_spec = V1LightningworkDrives(
drive=V1Drive(
metadata=V1Metadata(
name="test-work.s3_drive",
name="test-work.lit_drive_1",
),
spec=V1DriveSpec(
drive_type=V1DriveType.INDEXED_S3,
drive_type=V1DriveType.NO_MOUNT_S3,
source_type=V1SourceType.S3,
source="s3://some-bucket/path/",
source="lit://foobar",
),
status=V1DriveStatus(),
),
mount_location="/hello/",
mount_location=str(tmpdir),
)
lit_drive_spec = V1LightningworkDrives(
lit_drive_2_spec = V1LightningworkDrives(
drive=V1Drive(
metadata=V1Metadata(
name="test-work.lit_drive",
name="test-work.lit_drive_2",
),
spec=V1DriveSpec(
drive_type=V1DriveType.NO_MOUNT_S3,
@@ -562,7 +551,7 @@ class TestAppCreationClient:
),
image="random_base_public_image",
),
drives=[lit_drive_spec, s3_drive_spec],
drives=[lit_drive_2_spec, lit_drive_1_spec],
user_requested_compute_config=V1UserRequestedComputeConfig(
name="default", count=1, disk_size=0, preemptible=False, shm_size=0
),
@@ -595,7 +584,7 @@ class TestAppCreationClient:
),
image="random_base_public_image",
),
drives=[s3_drive_spec, lit_drive_spec],
drives=[lit_drive_1_spec, lit_drive_2_spec],
user_requested_compute_config=V1UserRequestedComputeConfig(
name="default", count=1, disk_size=0, preemptible=False, shm_size=0
),
@@ -632,6 +621,153 @@ class TestAppCreationClient:
project_id="test-project-id", app_id=mock.ANY, id=mock.ANY, body=mock.ANY
)
@mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock())
@pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
def test_call_with_work_app_and_attached_mount_and_drive(self, lightningapps, monkeypatch, tmpdir):
source_code_root_dir = Path(tmpdir / "src").absolute()
source_code_root_dir.mkdir()
Path(source_code_root_dir / ".lightning").write_text("name: myapp")
requirements_file = Path(source_code_root_dir / "requirements.txt")
Path(requirements_file).touch()
mock_client = mock.MagicMock()
if lightningapps:
lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED
mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
V1ListLightningappInstancesResponse(lightningapps=lightningapps)
)
lightning_app_instance = MagicMock()
mock_client.lightningapp_v2_service_create_lightningapp_release = MagicMock(return_value=lightning_app_instance)
mock_client.lightningapp_v2_service_create_lightningapp_release_instance = MagicMock(
return_value=lightning_app_instance
)
existing_instance = MagicMock()
existing_instance.status.phase = V1LightningappInstanceState.STOPPED
mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance)
cloud_backend = mock.MagicMock()
cloud_backend.client = mock_client
monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
app = mock.MagicMock()
flow = mock.MagicMock()
mocked_drive = MagicMock(spec=Drive)
setattr(mocked_drive, "id", "foobar")
setattr(mocked_drive, "protocol", "lit://")
setattr(mocked_drive, "component_name", "test-work")
setattr(mocked_drive, "allow_duplicates", False)
setattr(mocked_drive, "root_folder", tmpdir)
# deepcopy on a MagicMock instance will return an empty magicmock instance. To
# overcome this we set the __deepcopy__ method `return_value` to equal what
# should be the results of the deepcopy operation (an instance of the original class)
mocked_drive.__deepcopy__.return_value = copy(mocked_drive)
mocked_mount = MagicMock(spec=Mount)
setattr(mocked_mount, "source", "s3://foo/")
setattr(mocked_mount, "root_dir", "/content/")
setattr(mocked_mount, "protocol", "s3://")
work = WorkWithSingleDrive()
monkeypatch.setattr(work, "drive", mocked_drive)
monkeypatch.setattr(work, "_state", {"_port", "drive"})
monkeypatch.setattr(work, "_name", "test-work")
monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
monkeypatch.setattr(work._cloud_build_config, "image", "random_base_public_image")
monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
monkeypatch.setattr(work._cloud_compute, "preemptible", False)
monkeypatch.setattr(work._cloud_compute, "mounts", mocked_mount)
monkeypatch.setattr(work, "_port", 8080)
flow.works = lambda recurse: [work]
app.flows = [flow]
cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py"))
monkeypatch.setattr(
"lightning_app.runners.cloud._get_project",
lambda x: V1Membership(name="test-project", project_id="test-project-id"),
)
cloud_runtime.dispatch()
if lightningapps:
expected_body = Body8(
description=None,
local_source=True,
app_entrypoint_file="entrypoint.py",
enable_app_server=True,
flow_servers=[],
dependency_cache_key=get_hash(requirements_file),
image_spec=Gridv1ImageSpec(
dependency_file_info=V1DependencyFileInfo(
package_manager=V1PackageManager.PIP, path="requirements.txt"
)
),
works=[
V1Work(
name="test-work",
spec=V1LightningworkSpec(
build_spec=V1BuildSpec(
commands=["echo 'start'"],
python_dependencies=V1PythonDependencyInfo(
package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
),
image="random_base_public_image",
),
drives=[
V1LightningworkDrives(
drive=V1Drive(
metadata=V1Metadata(
name="test-work.drive",
),
spec=V1DriveSpec(
drive_type=V1DriveType.NO_MOUNT_S3,
source_type=V1SourceType.S3,
source="lit://foobar",
),
status=V1DriveStatus(),
),
mount_location=str(tmpdir),
),
V1LightningworkDrives(
drive=V1Drive(
metadata=V1Metadata(
name="test-work",
),
spec=V1DriveSpec(
drive_type=V1DriveType.INDEXED_S3,
source_type=V1SourceType.S3,
source="s3://foo/",
),
status=V1DriveStatus(),
),
mount_location="/content/",
),
],
user_requested_compute_config=V1UserRequestedComputeConfig(
name="default", count=1, disk_size=0, preemptible=False, shm_size=0
),
network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
),
)
],
)
mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
project_id="test-project-id", app_id=mock.ANY, body=expected_body
)
# running dispatch with disabled dependency cache
mock_client.reset_mock()
monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True)
expected_body.dependency_cache_key = None
cloud_runtime.dispatch()
mock_client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
project_id="test-project-id", app_id=mock.ANY, body=expected_body
)
else:
mock_client.lightningapp_v2_service_create_lightningapp_release_instance.assert_called_once_with(
project_id="test-project-id", app_id=mock.ANY, id=mock.ANY, body=mock.ANY
)
@mock.patch("lightning_app.core.queues.QueuingSystem", MagicMock())
@mock.patch("lightning_app.runners.backends.cloud.LightningClient", MagicMock())

View File

@@ -213,44 +213,7 @@ def test_lit_drive():
os.remove("a.txt")
def test_s3_drives():
drive = Drive("s3://foo/", allow_duplicates=True)
drive.component_name = "root.work"
with pytest.raises(
Exception, match="S3 based drives cannot currently add files via this API. Did you mean to use `lit://` drives?"
):
drive.put("a.txt")
with pytest.raises(
Exception,
match="S3 based drives cannot currently list files via this API. Did you mean to use `lit://` drives?",
):
drive.list("a.txt")
with pytest.raises(
Exception, match="S3 based drives cannot currently get files via this API. Did you mean to use `lit://` drives?"
):
drive.get("a.txt")
with pytest.raises(
Exception,
match="S3 based drives cannot currently delete files via this API. Did you mean to use `lit://` drives?",
):
drive.delete("a.txt")
_set_flow_context()
with pytest.raises(Exception, match="The flow isn't allowed to put files into a Drive."):
drive.put("a.txt")
with pytest.raises(Exception, match="The flow isn't allowed to list files from a Drive."):
drive.list("a.txt")
with pytest.raises(Exception, match="The flow isn't allowed to get files from a Drive."):
drive.get("a.txt")
def test_create_s3_drive_without_trailing_slash_fails():
with pytest.raises(ValueError, match="S3 drives must end in a trailing slash"):
Drive("s3://foo")
@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"])
@pytest.mark.parametrize("drive_id", ["lit://drive"])
def test_maybe_create_drive(drive_id):
drive = Drive(drive_id, allow_duplicates=False)
drive.component_name = "root.work1"
@@ -260,7 +223,7 @@ def test_maybe_create_drive(drive_id):
assert new_drive.component_name == drive.component_name
@pytest.mark.parametrize("drive_id", ["lit://drive", "s3://drive/"])
@pytest.mark.parametrize("drive_id", ["lit://drive"])
def test_drive_deepcopy(drive_id):
drive = Drive(drive_id, allow_duplicates=True)
drive.component_name = "root.work1"
@@ -269,8 +232,9 @@ def test_drive_deepcopy(drive_id):
assert new_drive.component_name == drive.component_name
def test_drive_root_folder_pass():
Drive("s3://drive/", root_folder="a")
def test_s3_drive_raises_error_telling_users_to_use_mounts():
with pytest.raises(ValueError, match="Using S3 buckets in a Drive is no longer supported."):
Drive("s3://foo/")
def test_drive_root_folder_breaks():

View File

@@ -0,0 +1,37 @@
import pytest
from lightning_app.storage.mount import Mount
def test_create_s3_mount_successfully():
mount = Mount(source="s3://foo/bar/", root_dir="./foo")
assert mount.source == "s3://foo/bar/"
assert mount.root_dir == "./foo"
assert mount.protocol == "s3://"
def test_create_non_s3_mount_fails():
with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
Mount(source="foo/bar/", root_dir="./foo")
with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
Mount(source="gcs://foo/bar/", root_dir="./foo")
with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
Mount(source="3://foo/bar/", root_dir="./foo")
def test_create_s3_mount_without_directory_prefix_fails():
with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
Mount(source="s3://foo/bar", root_dir="./foo")
with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
Mount(source="s3://foo", root_dir="./foo")
def test_create_mount_without_root_dir_argument():
m = Mount(source="s3://foo/")
assert m.root_dir == "/data/foo"
m = Mount(source="s3://foo/bar/")
assert m.root_dir == "/data/bar"

View File

@@ -53,6 +53,7 @@ def test_dict():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},
@@ -87,6 +88,7 @@ def test_dict():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},
@@ -121,6 +123,7 @@ def test_dict():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},
@@ -207,6 +210,7 @@ def test_list():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},
@@ -241,6 +245,7 @@ def test_list():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},
@@ -270,6 +275,7 @@ def test_list():
"preemptible": False,
"wait_timeout": None,
"idle_timeout": None,
"mounts": None,
"shm_size": 0,
"_internal_id": "default",
},

View File

@@ -1,6 +1,7 @@
import pytest
from lightning_app import CloudCompute
from lightning_app.storage import Mount
def test_cloud_compute_unsupported_features():
@@ -17,6 +18,33 @@ def test_cloud_compute_names():
def test_cloud_compute_shared_memory():
cloud_compute = CloudCompute("gpu", shm_size=1100)
assert cloud_compute.shm_size == 1100
def test_cloud_compute_with_mounts():
mount_1 = Mount(source="s3://foo/", root_dir="./foo")
mount_2 = Mount(source="s3://foo/bar/", root_dir="./bar")
cloud_compute = CloudCompute("gpu", mounts=mount_1)
assert cloud_compute.mounts == mount_1
cloud_compute = CloudCompute("gpu", mounts=[mount_1, mount_2])
assert cloud_compute.mounts == [mount_1, mount_2]
cc_dict = cloud_compute.to_dict()
assert "mounts" in cc_dict
assert cc_dict["mounts"] == [
{"root_dir": "./foo", "source": "s3://foo/"},
{"root_dir": "./bar", "source": "s3://foo/bar/"},
]
assert CloudCompute.from_dict(cc_dict) == cloud_compute
def test_cloud_compute_with_non_unique_mount_root_dirs():
mount_1 = Mount(source="s3://foo/", root_dir="./foo")
mount_2 = Mount(source="s3://foo/bar/", root_dir="./foo")
with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique 'root_dir' argument."):
CloudCompute("gpu", mounts=[mount_1, mount_2])