122 lines
3.2 KiB
Python
122 lines
3.2 KiB
Python
import os
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
|
|
from lightning.data.fileio import is_path, is_url, open_single_file, OpenCloudFileObj, path_to_url
|
|
|
|
|
|
_IS_URL_CASES = [
    ("s3://my_bucket/a", True),
    ("s3:/my_bucket", False),
    ("my_bucket", False),
    ("my_bucket_s3://", False),
]


@pytest.mark.parametrize(("input_str", "expected"), _IS_URL_CASES)
def test_is_url(input_str, expected):
    """Check that ``is_url`` accepts an ``s3://`` URL and rejects near-misses.

    The rejected inputs are a single-slash scheme, a bare name, and a scheme that
    does not appear at the start of the string.
    """
    assert is_url(input_str) == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "input_str, expected",
    [
        ("s3://my_bucket/a", False),
        ("s3:/my_bucket", False),
        ("my_bucket", False),
        ("my_bucket_s3://", False),
        ("/my_bucket", True),
    ],
)
def test_is_path(input_str, expected):
    """Check that ``is_path`` accepts a path starting with ``/``.

    URLs, URL-like strings and bare names are all expected to be rejected.
    """
    result = is_path(input_str)
    assert result == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("path", "bucket_name", "bucket_root_path", "expected"),
    [
        # The part of ``path`` below ``bucket_root_path`` becomes the object key.
        pytest.param("/data/abc/def", "my_bucket", "/data/abc", "s3://my_bucket/def"),
        pytest.param("/data/abc/def", "my_bucket", "/data", "s3://my_bucket/abc/def"),
    ],
)
def test_path_to_url(path, bucket_name, bucket_root_path, expected):
    """Check that ``path_to_url`` maps a local path under ``bucket_root_path`` to an ``s3://`` URL."""
    url = path_to_url(path, bucket_name, bucket_root_path)
    assert url == expected
|
|
|
|
|
|
def test_path_to_url_error():
    """``path_to_url`` must raise ``ValueError`` for a path that is not under the bucket root."""
    outside_path, bucket_root = "/path1/abc", "/path2"
    expected_message = "Cannot create a path from /path1/abc relative to /path2"

    with pytest.raises(ValueError, match=expected_message):
        path_to_url(outside_path, "foo", bucket_root)
|
|
|
|
|
|
@pytest.mark.parametrize("path", ["s3://my_bucket/da.txt", "abc.txt"])
@mock.patch("s3fs.S3FileSystem", autospec=True)
def test_read_single_file_read(patch: mock.Mock, path, tmp_path):
    """Read a file through ``open_single_file`` for both an S3 URL and a local path.

    ``s3fs.S3FileSystem`` is replaced by an autospec mock:

    * For the S3 case, only the wiring is checked: the wrapped file object is a
      mock, and ``open`` was called once on the filesystem instance.
    * For the local case, a real file is written under ``tmp_path`` and its
      content is read back.
    """
    from torchdata.datapipes.utils import StreamWrapper

    is_s3 = is_url(path)

    if not is_s3:
        path = os.path.join(tmp_path, path)
        with open(path, "w") as f:
            f.write("mytestfile")

    file_stream = open_single_file(path)
    assert isinstance(file_stream, StreamWrapper)

    try:
        content = file_stream.read()
    finally:
        # In the local case the wrapper holds a real open file handle; release it.
        file_stream.close()

    if is_s3:
        assert isinstance(file_stream.file_obj, mock.Mock)
        # ``assert_called_once`` must actually be *called*: the bare attribute is a
        # bound method, which is always truthy, so ``assert ...assert_called_once``
        # can never fail. Calls made on the instance created from the patched
        # class are recorded on ``patch.return_value``, not on ``patch`` itself.
        patch.return_value.open.assert_called_once()
    else:
        assert content == "mytestfile"
|
|
|
|
|
|
@pytest.mark.parametrize("path", ["s3://my_bucket/da.txt", "abc.txt"])
@mock.patch("s3fs.S3FileSystem", autospec=True)
def test_read_single_file_write(patch: mock.Mock, path, tmp_path):
    """Write a file through ``open_single_file(mode="w")`` for both an S3 URL and a local path.

    ``s3fs.S3FileSystem`` is replaced by an autospec mock:

    * For the S3 case, only the wiring is checked: the wrapped file object is a
      mock, and ``open`` was called once on the filesystem instance.
    * For the local case, the written content is read back from disk.
    """
    from torchdata.datapipes.utils import StreamWrapper

    is_s3 = is_url(path)

    if not is_s3:
        path = os.path.join(tmp_path, path)

    file_stream = open_single_file(path, mode="w")
    assert isinstance(file_stream, StreamWrapper)
    try:
        file_stream.write("mytestfile")
    finally:
        # Closing flushes the write; do it even if ``write`` fails so no handle leaks.
        file_stream.close()

    if is_s3:
        assert isinstance(file_stream.file_obj, mock.Mock)
        # Call the assertion method; the bare attribute is always truthy. Instance
        # calls on the patched class are recorded on ``patch.return_value``.
        patch.return_value.open.assert_called_once()
    else:
        with open(path) as f:
            assert f.read() == "mytestfile"
|
|
|
|
|
|
def test_open_cloud_file_obj(tmp_path):
    """Exercise ``OpenCloudFileObj`` on a local file in three ways.

    The three modes of use are: as a context manager, with an explicit
    ``close()``, and opened in write mode.
    """
    target = os.path.join(tmp_path, "foo.txt")
    with open(target, "w") as handle:
        handle.write("bar!")

    # Context-manager use: the underlying stream is closed once the block exits.
    cloud_file = OpenCloudFileObj(target)
    with cloud_file:
        assert cloud_file.read() == "bar!"
    assert cloud_file._stream.closed

    # Manual use: read without ``with``, then close explicitly.
    cloud_file = OpenCloudFileObj(target)
    assert cloud_file.read() == "bar!"
    cloud_file.close()
    assert cloud_file._stream.closed

    # Write mode: the new content replaces the old content on disk.
    with OpenCloudFileObj(target, "w") as writer:
        writer.write("not bar anymore!")

    with open(target) as handle:
        assert handle.read() == "not bar anymore!"
|