Renamed Mount root_dir Argument to mount_path (#15228)

* renamed Mount argument

* fix tests

* Apply suggestions from code review

Co-authored-by: Luca Antiga <luca.antiga@gmail.com>

* updated examples as well

Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
Rick Izzo 2022-10-20 17:33:35 -04:00 committed by GitHub
parent 6a72a15a62
commit b6f01ab010
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 40 additions and 28 deletions

View File

@ -23,7 +23,7 @@ class Flow(L.LightningFlow):
cloud_compute=CloudCompute(
mounts=Mount(
source="s3://ryft-public-sample-data/esRedditJson/",
root_dir="/content/esRedditJson/",
mount_path="/content/esRedditJson/",
),
)
)

View File

@ -432,5 +432,5 @@ def _create_mount_drive_spec(work_name: str, mount: Mount) -> V1LightningworkDri
),
status=V1DriveStatus(),
),
mount_location=str(mount.root_dir),
mount_location=str(mount.mount_path),
)

View File

@ -1,3 +1,4 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List
@ -16,13 +17,13 @@ class Mount:
`s3` style identifier pointing to a bucket and (optionally) prefix to mount. For
example: `s3://foo/bar/`.
root_dir: An absolute directory path in the work where external data source should
mount_path: An absolute directory path in the work where external data source should
be mounted as a filesystem. This path should not already exist in your codebase.
If not included, then the mount_path will be set to `/data/<last folder name in the bucket>`
"""
source: str = ""
root_dir: str = ""
mount_path: str = ""
def __post_init__(self) -> None:
@ -42,8 +43,14 @@ class Mount:
f"Received: '{self.source}'. Mounting a single file is not currently supported."
)
if self.root_dir == "":
self.root_dir = f"/data/{Path(self.source).stem}"
if self.mount_path == "":
self.mount_path = f"/data/{Path(self.source).stem}"
if not os.path.isabs(self.mount_path):
raise ValueError(
f"mount_path argument must be an absolute path to a "
f"location; received relative path {self.mount_path}"
)
@property
def protocol(self) -> str:

View File

@ -141,9 +141,9 @@ class CloudCompute:
def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], Tuple[Mount]]) -> None:
if isinstance(mounts, (list, tuple, set)):
root_dirs = [mount.root_dir for mount in mounts]
if len(set(root_dirs)) != len(root_dirs):
raise ValueError("Every Mount attached to a work must have a unique 'root_dir' argument.")
mount_paths = [mount.mount_path for mount in mounts]
if len(set(mount_paths)) != len(mount_paths):
raise ValueError("Every Mount attached to a work must have a unique 'mount_path' argument.")
def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]:

View File

@ -665,7 +665,7 @@ class TestAppCreationClient:
mocked_mount = MagicMock(spec=Mount)
setattr(mocked_mount, "source", "s3://foo/")
setattr(mocked_mount, "root_dir", "/content/")
setattr(mocked_mount, "mount_path", "/content/foo")
setattr(mocked_mount, "protocol", "s3://")
work = WorkWithSingleDrive()
@ -740,7 +740,7 @@ class TestAppCreationClient:
),
status=V1DriveStatus(),
),
mount_location="/content/",
mount_location="/content/foo",
),
],
user_requested_compute_config=V1UserRequestedComputeConfig(

View File

@ -4,34 +4,39 @@ from lightning_app.storage.mount import Mount
def test_create_s3_mount_successfully():
mount = Mount(source="s3://foo/bar/", root_dir="./foo")
mount = Mount(source="s3://foo/bar/", mount_path="/foo")
assert mount.source == "s3://foo/bar/"
assert mount.root_dir == "./foo"
assert mount.mount_path == "/foo"
assert mount.protocol == "s3://"
def test_create_non_s3_mount_fails():
with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
Mount(source="foo/bar/", root_dir="./foo")
Mount(source="foo/bar/", mount_path="/foo")
with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
Mount(source="gcs://foo/bar/", root_dir="./foo")
Mount(source="gcs://foo/bar/", mount_path="/foo")
with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
Mount(source="3://foo/bar/", root_dir="./foo")
Mount(source="3://foo/bar/", mount_path="/foo")
def test_create_s3_mount_without_directory_prefix_fails():
with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
Mount(source="s3://foo/bar", root_dir="./foo")
Mount(source="s3://foo/bar", mount_path="/foo")
with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
Mount(source="s3://foo", root_dir="./foo")
Mount(source="s3://foo", mount_path="/foo")
def test_create_mount_without_root_dir_argument():
def test_create_mount_without_mount_path_argument():
m = Mount(source="s3://foo/")
assert m.root_dir == "/data/foo"
assert m.mount_path == "/data/foo"
m = Mount(source="s3://foo/bar/")
assert m.root_dir == "/data/bar"
assert m.mount_path == "/data/bar"
def test_create_mount_path_with_relative_path_errors():
with pytest.raises(ValueError, match="mount_path argument must be an absolute path"):
Mount(source="s3://foo/", mount_path="./doesnotwork")

View File

@ -23,8 +23,8 @@ def test_cloud_compute_shared_memory():
def test_cloud_compute_with_mounts():
mount_1 = Mount(source="s3://foo/", root_dir="./foo")
mount_2 = Mount(source="s3://foo/bar/", root_dir="./bar")
mount_1 = Mount(source="s3://foo/", mount_path="/foo")
mount_2 = Mount(source="s3://foo/bar/", mount_path="/bar")
cloud_compute = CloudCompute("gpu", mounts=mount_1)
assert cloud_compute.mounts == mount_1
@ -35,16 +35,16 @@ def test_cloud_compute_with_mounts():
cc_dict = cloud_compute.to_dict()
assert "mounts" in cc_dict
assert cc_dict["mounts"] == [
{"root_dir": "./foo", "source": "s3://foo/"},
{"root_dir": "./bar", "source": "s3://foo/bar/"},
{"mount_path": "/foo", "source": "s3://foo/"},
{"mount_path": "/bar", "source": "s3://foo/bar/"},
]
assert CloudCompute.from_dict(cc_dict) == cloud_compute
def test_cloud_compute_with_non_unique_mount_root_dirs():
mount_1 = Mount(source="s3://foo/", root_dir="./foo")
mount_2 = Mount(source="s3://foo/bar/", root_dir="./foo")
mount_1 = Mount(source="s3://foo/", mount_path="/foo")
mount_2 = Mount(source="s3://foo/bar/", mount_path="/foo")
with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique 'root_dir' argument."):
with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique"):
CloudCompute("gpu", mounts=[mount_1, mount_2])