# lightning/tests/tests_data/test_fileio.py
#
# 122 lines
# 3.2 KiB
# Python

import os
from unittest import mock
import pytest
from lightning.data.fileio import is_path, is_url, open_single_file, OpenCloudFileObj, path_to_url
@pytest.mark.parametrize(
    ("input_str", "expected"),
    (
        pytest.param("s3://my_bucket/a", True),
        pytest.param("s3:/my_bucket", False),
        pytest.param("my_bucket", False),
        pytest.param("my_bucket_s3://", False),
    ),
)
def test_is_url(input_str, expected):
    """Check that ``is_url`` recognises ``s3://`` URLs and rejects malformed or non-leading schemes."""
    detected = is_url(input_str)
    assert detected == expected


@pytest.mark.parametrize(
    ("input_str", "expected"),
    (
        pytest.param("s3://my_bucket/a", False),
        pytest.param("s3:/my_bucket", False),
        pytest.param("my_bucket", False),
        pytest.param("my_bucket_s3://", False),
        pytest.param("/my_bucket", True),
    ),
)
def test_is_path(input_str, expected):
    """Check that ``is_path`` accepts ``/``-rooted paths and rejects URLs and bare names."""
    detected = is_path(input_str)
    assert detected == expected


@pytest.mark.parametrize(
    ("path", "bucket_name", "bucket_root_path", "expected"),
    (
        pytest.param("/data/abc/def", "my_bucket", "/data/abc", "s3://my_bucket/def"),
        pytest.param("/data/abc/def", "my_bucket", "/data", "s3://my_bucket/abc/def"),
    ),
)
def test_path_to_url(path, bucket_name, bucket_root_path, expected):
    """Check that a local path under ``bucket_root_path`` maps to the matching key in the ``s3://`` bucket."""
    url = path_to_url(path, bucket_name, bucket_root_path)
    assert url == expected


def test_path_to_url_error():
    """A path that does not live under ``bucket_root_path`` must be rejected with ``ValueError``."""
    expected_message = "Cannot create a path from /path1/abc relative to /path2"
    with pytest.raises(ValueError, match=expected_message):
        path_to_url("/path1/abc", "foo", "/path2")


@pytest.mark.parametrize("path", ["s3://my_bucket/da.txt", "abc.txt"])
@mock.patch("s3fs.S3FileSystem", autospec=True)
def test_read_single_file_read(patch: mock.Mock, path, tmp_path):
    """Read through ``open_single_file`` from a mocked S3 URL and from a real local file.

    ``patch`` is the autospecced ``s3fs.S3FileSystem`` class injected by ``mock.patch``.
    For the local case the file is created under ``tmp_path`` first and its content is
    compared; for the S3 case only the interaction with the mocked handle is checked.
    """
    from torchdata.datapipes.utils import StreamWrapper

    is_s3 = is_url(path)
    if not is_s3:
        path = os.path.join(tmp_path, path)
        with open(path, "w") as f:
            f.write("mytestfile")

    file_stream = open_single_file(path)
    assert isinstance(file_stream, StreamWrapper)
    content = file_stream.read()
    # Close explicitly so the local file handle is not left open after the test.
    file_stream.close()

    if is_s3:
        # The wrapped object must come from the mocked filesystem, not a real S3 handle.
        assert isinstance(file_stream.file_obj, mock.Mock)
        # Assert on the handle itself rather than on `patch.open`. `patch` is the mocked
        # *class*, so a call made through a filesystem instance is not recorded there.
        # `assert patch.open.assert_called_once` (without parentheses) was always truthy.
        file_stream.file_obj.read.assert_called_once_with()
    else:
        assert content == "mytestfile"


@pytest.mark.parametrize("path", ["s3://my_bucket/da.txt", "abc.txt"])
@mock.patch("s3fs.S3FileSystem", autospec=True)
def test_read_single_file_write(patch: mock.Mock, path, tmp_path):
    """Write through ``open_single_file(..., mode="w")`` to a mocked S3 URL and to a local file.

    ``patch`` is the autospecced ``s3fs.S3FileSystem`` class injected by ``mock.patch``.
    For the local case the written file is read back; for the S3 case the test checks that
    the write and the close reached the mocked handle.
    """
    from torchdata.datapipes.utils import StreamWrapper

    is_s3 = is_url(path)
    if not is_s3:
        path = os.path.join(tmp_path, path)

    file_stream = open_single_file(path, mode="w")
    assert isinstance(file_stream, StreamWrapper)
    file_stream.write("mytestfile")
    file_stream.close()

    if is_s3:
        # The wrapped object must come from the mocked filesystem, not a real S3 handle.
        assert isinstance(file_stream.file_obj, mock.Mock)
        # Check the handle rather than `patch.open`: `patch` is the mocked *class*, and
        # `assert patch.open.assert_called_once` (without parentheses) was always truthy.
        file_stream.file_obj.write.assert_called_once_with("mytestfile")
        # Closing the wrapper must close the underlying handle, which is what flushes a remote write.
        file_stream.file_obj.close.assert_called_once_with()
    else:
        with open(path) as f:
            assert f.read() == "mytestfile"


def test_open_cloud_file_obj(tmp_path):
    """Exercise ``OpenCloudFileObj`` as a context manager, with explicit ``close()``, and in write mode."""
    target = os.path.join(tmp_path, "foo.txt")
    with open(target, "w") as seed:
        seed.write("bar!")

    # Context-manager use: leaving the ``with`` block must close the underlying stream.
    reader = OpenCloudFileObj(target)
    with reader:
        assert reader.read() == "bar!"
    assert reader._stream.closed

    # Manual use: an explicit ``close()`` must also close the underlying stream.
    reader = OpenCloudFileObj(target)
    assert reader.read() == "bar!"
    reader.close()
    assert reader._stream.closed

    # Write mode: afterwards the file holds exactly the newly written text.
    with OpenCloudFileObj(target, "w") as writer:
        writer.write("not bar anymore!")
    with open(target) as check:
        assert check.read() == "not bar anymore!"