Renamed Mount root_dir Argument to mount_path (#15228)
* renamed Mount argument * fix tests * Apply suggestions from code review Co-authored-by: Luca Antiga <luca.antiga@gmail.com> * updated examples as well Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
parent
6a72a15a62
commit
b6f01ab010
|
@ -23,7 +23,7 @@ class Flow(L.LightningFlow):
|
|||
cloud_compute=CloudCompute(
|
||||
mounts=Mount(
|
||||
source="s3://ryft-public-sample-data/esRedditJson/",
|
||||
root_dir="/content/esRedditJson/",
|
||||
mount_path="/content/esRedditJson/",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
|
@ -432,5 +432,5 @@ def _create_mount_drive_spec(work_name: str, mount: Mount) -> V1LightningworkDri
|
|||
),
|
||||
status=V1DriveStatus(),
|
||||
),
|
||||
mount_location=str(mount.root_dir),
|
||||
mount_location=str(mount.mount_path),
|
||||
)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
@ -16,13 +17,13 @@ class Mount:
|
|||
`s3` style identifier pointing to a bucket and (optionally) prefix to mount. For
|
||||
example: `s3://foo/bar/`.
|
||||
|
||||
root_dir: An absolute directory path in the work where external data source should
|
||||
mount_path: An absolute directory path in the work where external data source should
|
||||
be mounted as a filesystem. This path should not already exist in your codebase.
|
||||
If not included, then the root_dir will be set to `/data/<last folder name in the bucket>`
|
||||
"""
|
||||
|
||||
source: str = ""
|
||||
root_dir: str = ""
|
||||
mount_path: str = ""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
||||
|
@ -42,8 +43,14 @@ class Mount:
|
|||
f"Received: '{self.source}'. Mounting a single file is not currently supported."
|
||||
)
|
||||
|
||||
if self.root_dir == "":
|
||||
self.root_dir = f"/data/{Path(self.source).stem}"
|
||||
if self.mount_path == "":
|
||||
self.mount_path = f"/data/{Path(self.source).stem}"
|
||||
|
||||
if not os.path.isabs(self.mount_path):
|
||||
raise ValueError(
|
||||
f"mount_path argument must be an absolute path to a "
|
||||
f"location; received relative path {self.mount_path}"
|
||||
)
|
||||
|
||||
@property
|
||||
def protocol(self) -> str:
|
||||
|
|
|
@ -141,9 +141,9 @@ class CloudCompute:
|
|||
|
||||
def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], Tuple[Mount]]) -> None:
|
||||
if isinstance(mounts, (list, tuple, set)):
|
||||
root_dirs = [mount.root_dir for mount in mounts]
|
||||
if len(set(root_dirs)) != len(root_dirs):
|
||||
raise ValueError("Every Mount attached to a work must have a unique 'root_dir' argument.")
|
||||
mount_paths = [mount.mount_path for mount in mounts]
|
||||
if len(set(mount_paths)) != len(mount_paths):
|
||||
raise ValueError("Every Mount attached to a work must have a unique 'mount_path' argument.")
|
||||
|
||||
|
||||
def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]:
|
||||
|
|
|
@ -665,7 +665,7 @@ class TestAppCreationClient:
|
|||
|
||||
mocked_mount = MagicMock(spec=Mount)
|
||||
setattr(mocked_mount, "source", "s3://foo/")
|
||||
setattr(mocked_mount, "root_dir", "/content/")
|
||||
setattr(mocked_mount, "mount_path", "/content/foo")
|
||||
setattr(mocked_mount, "protocol", "s3://")
|
||||
|
||||
work = WorkWithSingleDrive()
|
||||
|
@ -740,7 +740,7 @@ class TestAppCreationClient:
|
|||
),
|
||||
status=V1DriveStatus(),
|
||||
),
|
||||
mount_location="/content/",
|
||||
mount_location="/content/foo",
|
||||
),
|
||||
],
|
||||
user_requested_compute_config=V1UserRequestedComputeConfig(
|
||||
|
|
|
@ -4,34 +4,39 @@ from lightning_app.storage.mount import Mount
|
|||
|
||||
|
||||
def test_create_s3_mount_successfully():
|
||||
mount = Mount(source="s3://foo/bar/", root_dir="./foo")
|
||||
mount = Mount(source="s3://foo/bar/", mount_path="/foo")
|
||||
assert mount.source == "s3://foo/bar/"
|
||||
assert mount.root_dir == "./foo"
|
||||
assert mount.mount_path == "/foo"
|
||||
assert mount.protocol == "s3://"
|
||||
|
||||
|
||||
def test_create_non_s3_mount_fails():
|
||||
with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
|
||||
Mount(source="foo/bar/", root_dir="./foo")
|
||||
Mount(source="foo/bar/", mount_path="/foo")
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
|
||||
Mount(source="gcs://foo/bar/", root_dir="./foo")
|
||||
Mount(source="gcs://foo/bar/", mount_path="/foo")
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
|
||||
Mount(source="3://foo/bar/", root_dir="./foo")
|
||||
Mount(source="3://foo/bar/", mount_path="/foo")
|
||||
|
||||
|
||||
def test_create_s3_mount_without_directory_prefix_fails():
|
||||
with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
|
||||
Mount(source="s3://foo/bar", root_dir="./foo")
|
||||
Mount(source="s3://foo/bar", mount_path="/foo")
|
||||
|
||||
with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
|
||||
Mount(source="s3://foo", root_dir="./foo")
|
||||
Mount(source="s3://foo", mount_path="/foo")
|
||||
|
||||
|
||||
def test_create_mount_without_root_dir_argument():
|
||||
def test_create_mount_without_mount_path_argument():
|
||||
m = Mount(source="s3://foo/")
|
||||
assert m.root_dir == "/data/foo"
|
||||
assert m.mount_path == "/data/foo"
|
||||
|
||||
m = Mount(source="s3://foo/bar/")
|
||||
assert m.root_dir == "/data/bar"
|
||||
assert m.mount_path == "/data/bar"
|
||||
|
||||
|
||||
def test_create_mount_path_with_relative_path_errors():
|
||||
with pytest.raises(ValueError, match="mount_path argument must be an absolute path"):
|
||||
Mount(source="s3://foo/", mount_path="./doesnotwork")
|
||||
|
|
|
@ -23,8 +23,8 @@ def test_cloud_compute_shared_memory():
|
|||
|
||||
|
||||
def test_cloud_compute_with_mounts():
|
||||
mount_1 = Mount(source="s3://foo/", root_dir="./foo")
|
||||
mount_2 = Mount(source="s3://foo/bar/", root_dir="./bar")
|
||||
mount_1 = Mount(source="s3://foo/", mount_path="/foo")
|
||||
mount_2 = Mount(source="s3://foo/bar/", mount_path="/bar")
|
||||
|
||||
cloud_compute = CloudCompute("gpu", mounts=mount_1)
|
||||
assert cloud_compute.mounts == mount_1
|
||||
|
@ -35,16 +35,16 @@ def test_cloud_compute_with_mounts():
|
|||
cc_dict = cloud_compute.to_dict()
|
||||
assert "mounts" in cc_dict
|
||||
assert cc_dict["mounts"] == [
|
||||
{"root_dir": "./foo", "source": "s3://foo/"},
|
||||
{"root_dir": "./bar", "source": "s3://foo/bar/"},
|
||||
{"mount_path": "/foo", "source": "s3://foo/"},
|
||||
{"mount_path": "/bar", "source": "s3://foo/bar/"},
|
||||
]
|
||||
|
||||
assert CloudCompute.from_dict(cc_dict) == cloud_compute
|
||||
|
||||
|
||||
def test_cloud_compute_with_non_unique_mount_root_dirs():
|
||||
mount_1 = Mount(source="s3://foo/", root_dir="./foo")
|
||||
mount_2 = Mount(source="s3://foo/bar/", root_dir="./foo")
|
||||
mount_1 = Mount(source="s3://foo/", mount_path="/foo")
|
||||
mount_2 = Mount(source="s3://foo/bar/", mount_path="/foo")
|
||||
|
||||
with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique 'root_dir' argument."):
|
||||
with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique"):
|
||||
CloudCompute("gpu", mounts=[mount_1, mount_2])
|
||||
|
|
Loading…
Reference in New Issue