lightning/tests/tests_data/processing/test_readers.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

66 lines
1.9 KiB
Python
Raw Normal View History

import os
import sys
import pytest
from lightning.data import map
from lightning.data.processing.readers import _PYARROW_AVAILABLE, BaseReader, ParquetReader
class DummyReader(BaseReader):
    """Minimal reader that fans every input out to every worker."""

    def remap_items(self, items, num_workers: int):
        """Return one ``(worker_idx, idx, item)`` tuple per (input, worker) pair.

        Tuples are grouped by input position ``idx`` and, within each position,
        ordered by ``worker_idx`` from ``0`` to ``num_workers - 1``.
        """
        remapped = []
        for idx, item in enumerate(items):
            for worker_idx in range(num_workers):
                remapped.append((worker_idx, idx, item))
        return remapped

    def read(self, item):
        """Return ``item`` unchanged."""
        return item
def fn(data: tuple, output_dir):
    """Create a marker file named ``"{worker_idx}_{idx}"`` inside ``output_dir``.

    Args:
        data: A three-element ``(worker_idx, idx, item)`` sequence; the third
            element is ignored.
        output_dir: Directory in which the marker file is written.
    """
    worker_idx, idx, _ = data
    marker_path = os.path.join(output_dir, f"{worker_idx}_{idx}")
    with open(marker_path, "w") as f:
        f.write("hello world")
def test_reader(tmpdir):
    """``map`` with ``DummyReader`` writes one file per (worker, input) pair."""
    inputs = list(range(3))
    map(fn, inputs, output_dir=str(tmpdir), reader=DummyReader(), num_workers=2)

    # 2 workers x 3 inputs -> six "{worker_idx}_{idx}" marker files.
    produced = sorted(os.listdir(tmpdir))
    assert produced == ['0_0', '0_1', '0_2', '1_0', '1_1', '1_2']
def map_parquet(df, output_dir):
    """Write a single marker file describing ``df`` and return.

    The file is named ``"{first_value}_{num_rows}"`` where ``first_value`` is
    the first column of the first row yielded by ``df`` and ``num_rows`` is
    ``df.metadata.num_rows``.

    Args:
        df: Object exposing ``iter_batches(batch_size=...)`` and
            ``metadata.num_rows`` (presumably a ``pyarrow.parquet.ParquetFile``
            since the reader is built with ``to_pandas=False`` - TODO confirm).
        output_dir: Directory in which the marker file is written.
    """
    for batch in df.iter_batches(batch_size=1):
        for row in batch.to_pandas().values.tolist():
            filename = f"{row[0]}_{df.metadata.num_rows}"
            with open(os.path.join(output_dir, filename), "w") as f:
                f.write("hello world")
            # Only the first row is used: stop after writing one file.
            return
@pytest.mark.skipif(
    not _PYARROW_AVAILABLE or sys.platform == "linux",
    reason="pyarrow is required and this test is skipped on Linux",
)
def test_parquet_reader(tmpdir):
    """``map`` with ``ParquetReader`` calls ``map_parquet`` once per chunk.

    Three parquet files of ten consecutive integers each are written, then
    mapped with ``num_rows=5``. The expected file names imply six chunks whose
    first values are 0, 5, ..., 25, each reporting 5 rows.
    """
    import pandas as pd

    inputs = []
    for i in range(3):
        parquet_path = os.path.join(tmpdir, f"{i}.parquet")
        df = pd.DataFrame(list(range(i * 10, (i + 1) * 10)), columns=["value"])
        df.to_parquet(parquet_path)
        inputs.append(parquet_path)

    cache_folder = os.path.join(tmpdir, "cache")
    output_dir = os.path.join(tmpdir, "output_dir")

    map(
        map_parquet,
        inputs=inputs,
        output_dir=output_dir,
        reader=ParquetReader(cache_folder, num_rows=5, to_pandas=False),
        num_workers=2,
    )

    # Lexicographic sort, hence '5_5' comes last.
    assert sorted(os.listdir(output_dir)) == ['0_5', '10_5', '15_5', '20_5', '25_5', '5_5']