Add batch_size to map, optimize (#19489)

thomas chaton 2024-02-16 20:54:39 +00:00 committed by GitHub
parent bbc5488a62
commit bb35e8e0d3
2 changed files with 37 additions and 2 deletions


@@ -70,6 +70,12 @@ def _get_input_dir(inputs: Sequence[Any]) -> Optional[str]:
     return "/" + os.path.join(*str(absolute_path).split("/")[:4])
 
 
+def _get_default_num_workers() -> int:
+    if torch.cuda.is_available():
+        return torch.cuda.device_count()
+    return os.cpu_count() or 1
+
+
 class LambdaDataTransformRecipe(DataTransformRecipe):
     def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
         super().__init__()
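
For context, the new helper changes the fallback used when `num_workers` is not passed: one worker per CUDA device on GPU machines, otherwise the CPU count, and finally 1 (presumably because `os.cpu_count()` can return `None`). A standalone sketch of the resolution order; `default_num_workers` here is a local copy for illustration, not the library function:

```python
import os

import torch


def default_num_workers() -> int:
    # Mirrors the helper above: prefer one worker per GPU, then the CPU
    # count, then 1 as a last resort (os.cpu_count() may return None).
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    return os.cpu_count() or 1


num_workers = None  # i.e. the caller did not pass a value
print(num_workers or default_num_workers())  # GPU count on a CUDA box, else CPU count
```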
@@ -161,6 +167,7 @@ def map(
     reorder_files: bool = True,
     error_when_not_empty: bool = False,
     reader: Optional[BaseReader] = None,
+    batch_size: Optional[int] = None,
 ) -> None:
     """This function maps a callable over a collection of files, possibly in a distributed way.
 
@@ -178,6 +185,7 @@ def map(
         reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
             Set this to ``False`` if the order in which samples are processed should be preserved.
         error_when_not_empty: Whether we should error if the output folder isn't empty.
+        batch_size: Group the inputs into batches of ``batch_size`` elements.
 
     """
     if not isinstance(inputs, Sequence):
@@ -212,10 +220,13 @@ def map(
     input_dir = _resolve_dir(_get_input_dir(inputs))
 
+    if isinstance(batch_size, int) and batch_size > 1:
+        inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+
     data_processor = DataProcessor(
         input_dir=input_dir,
         output_dir=_output_dir,
-        num_workers=num_workers or os.cpu_count(),
+        num_workers=num_workers or _get_default_num_workers(),
         fast_dev_run=fast_dev_run,
         num_downloaders=num_downloaders,
         num_uploaders=num_uploaders,
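
The batching itself is a plain list-comprehension regroup applied before the `DataProcessor` is built. A minimal standalone sketch of what the worker callback receives once `batch_size > 1` (variable names local to this example):

```python
# Standalone sketch of the regrouping that map()/optimize() now apply
# before dispatching work when batch_size > 1.
inputs = list(range(5))
batch_size = 2

if isinstance(batch_size, int) and batch_size > 1:
    inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]

print(inputs)  # [[0, 1], [2, 3], [4]] -- the trailing batch may be shorter
```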
@@ -247,6 +258,7 @@ def optimize(
     num_uploaders: Optional[int] = None,
     reorder_files: bool = True,
     reader: Optional[BaseReader] = None,
+    batch_size: Optional[int] = None,
 ) -> None:
     """This function converts a dataset into chunks, possibly in a distributed way.
 
@@ -266,6 +278,7 @@ def optimize(
         num_uploaders: The number of uploaders per worker.
         reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
             Set this to ``False`` if the order in which samples are processed should be preserved.
+        batch_size: Group the inputs into batches of ``batch_size`` elements.
 
     """
     if not isinstance(inputs, Sequence):
@@ -302,10 +315,13 @@ def optimize(
     input_dir = _resolve_dir(_get_input_dir(inputs))
 
+    if isinstance(batch_size, int) and batch_size > 1:
+        inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+
     data_processor = DataProcessor(
         input_dir=input_dir,
         output_dir=_output_dir,
-        num_workers=num_workers or os.cpu_count(),
+        num_workers=num_workers or _get_default_num_workers(),
         fast_dev_run=fast_dev_run,
         num_downloaders=num_downloaders,
         num_uploaders=num_uploaders,
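
`optimize` receives the identical regrouping. A hypothetical usage sketch; the `lightning.data` import path, the `chunk_bytes` value, and the one-sample-per-call return convention are assumptions, not confirmed by this diff:

```python
from lightning.data import optimize  # import path assumed for this sketch


def process(batch):
    # With batch_size=2, `batch` is a list such as [0, 1] rather than a
    # single element; return one serializable sample per call (assumed).
    return {"indexes": batch, "total": sum(batch)}


if __name__ == "__main__":
    optimize(
        process,
        inputs=list(range(100)),
        output_dir="my_chunks",
        chunk_bytes="64MB",  # pre-existing parameter; value chosen arbitrarily
        batch_size=2,        # new parameter from this commit
    )
```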


@@ -1025,6 +1025,25 @@ def test_map_is_last(num_workers, expected, tmpdir):
     assert sorted(os.listdir(tmpdir)) == expected
 
 
+def map_batch_size_fn(indexes, output_dir):
+    path = os.path.join(output_dir, str(indexes))
+    with open(path, "w") as f:
+        f.write("hello world")
+
+
+def test_map_batch_size(tmpdir):
+    map(
+        map_batch_size_fn,
+        list(range(5)),
+        output_dir=str(tmpdir),
+        error_when_not_empty=False,
+        num_workers=1,
+        batch_size=2,
+    )
+
+    assert sorted(os.listdir(tmpdir)) == ["[0, 1]", "[2, 3]", "[4]"]
+
+
 def no_op(index):
     pass
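
The expected filenames carry brackets because each worker call receives a whole batch as a list and the test builds paths from `str(indexes)`. A one-line check:

```python
# str() of a list keeps the brackets and comma-space separator, which is
# exactly the shape of the expected filenames in test_map_batch_size.
print(str([0, 1]))  # [0, 1]
print(str([4]))     # [4]
```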