Add batch_size to map, optimize (#19489)
parent bbc5488a62
commit bb35e8e0d3
@@ -70,6 +70,12 @@ def _get_input_dir(inputs: Sequence[Any]) -> Optional[str]:
     return "/" + os.path.join(*str(absolute_path).split("/")[:4])


+def _get_default_num_workers() -> int:
+    if torch.cuda.is_available():
+        return torch.cuda.device_count()
+    return os.cpu_count() or 1
+
+
 class LambdaDataTransformRecipe(DataTransformRecipe):
     def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
         super().__init__()
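The helper above also changes the default worker count passed to DataProcessor (see the num_workers lines further down): an explicit num_workers still wins, otherwise the default is now the GPU count on CUDA machines rather than always the CPU count. A quick local check of the new fallback (illustrative only, not part of the patch):

    import os
    import torch

    # Mirrors _get_default_num_workers(): one worker per GPU when CUDA is available,
    # otherwise one per CPU core (with a floor of 1).
    default = torch.cuda.device_count() if torch.cuda.is_available() else (os.cpu_count() or 1)
    print(default)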
@@ -161,6 +167,7 @@ def map(
     reorder_files: bool = True,
     error_when_not_empty: bool = False,
     reader: Optional[BaseReader] = None,
+    batch_size: Optional[int] = None,
 ) -> None:
     """This function map a callbable over a collection of files possibly in a distributed way.

@@ -178,6 +185,7 @@ def map(
         reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
             Set this to ``False`` if the order in which samples are processed should be preserved.
         error_when_not_empty: Whether we should error if the output folder isn't empty.
+        batch_size: Group the inputs into batches of batch_size length.

     """
     if not isinstance(inputs, Sequence):
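A minimal usage sketch of the new argument, modeled on the test added at the end of this diff (the function name, output path, and the `from lightning.data import map` import are assumptions for illustration): when batch_size is set, the callable receives a list of inputs per call instead of a single item.

    import os
    from lightning.data import map  # import path assumed

    def process(batch, output_dir):
        # With batch_size=2, `batch` is a list such as [0, 1] rather than a single index.
        with open(os.path.join(output_dir, str(batch)), "w") as f:
            f.write("hello world")

    map(process, list(range(5)), output_dir="./out", batch_size=2, num_workers=1)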
@@ -212,10 +220,13 @@ def map(

     input_dir = _resolve_dir(_get_input_dir(inputs))

+    if isinstance(batch_size, int) and batch_size > 1:
+        inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+
     data_processor = DataProcessor(
         input_dir=input_dir,
         output_dir=_output_dir,
-        num_workers=num_workers or os.cpu_count(),
+        num_workers=num_workers or _get_default_num_workers(),
         fast_dev_run=fast_dev_run,
         num_downloaders=num_downloaders,
         num_uploaders=num_uploaders,
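The batching itself is plain list slicing, applied before the inputs reach the DataProcessor; for example:

    inputs = list(range(5))
    batch_size = 2
    batches = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
    # batches == [[0, 1], [2, 3], [4]]; the final batch may be shorter than batch_size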
@@ -247,6 +258,7 @@ def optimize(
     num_uploaders: Optional[int] = None,
     reorder_files: bool = True,
     reader: Optional[BaseReader] = None,
+    batch_size: Optional[int] = None,
 ) -> None:
     """This function converts a dataset into chunks possibly in a distributed way.

@@ -266,6 +278,7 @@ def optimize(
         num_uploaders: The numbers of uploaders per worker.
         reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
             Set this to ``False`` if the order in which samples are processed should be preserved.
+        batch_size: Group the inputs into batches of batch_size length.

     """
     if not isinstance(inputs, Sequence):
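A hedged usage sketch for optimize with the new argument (the function body, import path, output location, and chunk_bytes value are assumptions, not taken from this patch): as with map, the user function is invoked once per batch.

    from lightning.data import optimize  # import path assumed

    def pack_batch(batch):
        # Called once per batch, e.g. [0, 1]; returns one optimized sample per batch.
        return {"batch": batch, "total": sum(batch)}

    optimize(
        pack_batch,
        list(range(100)),
        output_dir="./optimized",  # illustrative
        chunk_bytes="64MB",        # illustrative; chunking config is unchanged by this patch
        batch_size=2,
    )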
@@ -302,10 +315,13 @@ def optimize(

     input_dir = _resolve_dir(_get_input_dir(inputs))

+    if isinstance(batch_size, int) and batch_size > 1:
+        inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+
     data_processor = DataProcessor(
         input_dir=input_dir,
         output_dir=_output_dir,
-        num_workers=num_workers or os.cpu_count(),
+        num_workers=num_workers or _get_default_num_workers(),
         fast_dev_run=fast_dev_run,
         num_downloaders=num_downloaders,
         num_uploaders=num_uploaders,
@@ -1025,6 +1025,25 @@ def test_map_is_last(num_workers, expected, tmpdir):
     assert sorted(os.listdir(tmpdir)) == expected


+def map_batch_size_fn(indexes, output_dir):
+    path = os.path.join(output_dir, str(indexes))
+    with open(path, "w") as f:
+        f.write("hello world")
+
+
+def test_map_batch_size(tmpdir):
+    map(
+        map_batch_size_fn,
+        list(range(5)),
+        output_dir=str(tmpdir),
+        error_when_not_empty=False,
+        num_workers=1,
+        batch_size=2,
+    )
+
+    assert sorted(os.listdir(tmpdir)) == ["[0, 1]", "[2, 3]", "[4]"]
+
+
 def no_op(index):
     pass
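With batch_size=2, each call to map_batch_size_fn receives a whole batch (e.g. [0, 1]) as indexes and writes a file named str(indexes), which is why the expected directory listing is ["[0, 1]", "[2, 3]", "[4]"].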