Add batch_size to map, optimize (#19489)
This commit is contained in:
parent
bbc5488a62
commit
bb35e8e0d3
|
@ -70,6 +70,12 @@ def _get_input_dir(inputs: Sequence[Any]) -> Optional[str]:
|
|||
return "/" + os.path.join(*str(absolute_path).split("/")[:4])
|
||||
|
||||
|
||||
def _get_default_num_workers() -> int:
|
||||
if torch.cuda.is_available():
|
||||
return torch.cuda.device_count()
|
||||
return os.cpu_count() or 1
|
||||
|
||||
|
||||
class LambdaDataTransformRecipe(DataTransformRecipe):
|
||||
def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
|
||||
super().__init__()
|
||||
|
@ -161,6 +167,7 @@ def map(
|
|||
reorder_files: bool = True,
|
||||
error_when_not_empty: bool = False,
|
||||
reader: Optional[BaseReader] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> None:
|
||||
"""This function map a callbable over a collection of files possibly in a distributed way.
|
||||
|
||||
|
@ -178,6 +185,7 @@ def map(
|
|||
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
|
||||
Set this to ``False`` if the order in which samples are processed should be preserved.
|
||||
error_when_not_empty: Whether we should error if the output folder isn't empty.
|
||||
batch_size: Group the inputs into batches of batch_size length.
|
||||
|
||||
"""
|
||||
if not isinstance(inputs, Sequence):
|
||||
|
@ -212,10 +220,13 @@ def map(
|
|||
|
||||
input_dir = _resolve_dir(_get_input_dir(inputs))
|
||||
|
||||
if isinstance(batch_size, int) and batch_size > 1:
|
||||
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
|
||||
|
||||
data_processor = DataProcessor(
|
||||
input_dir=input_dir,
|
||||
output_dir=_output_dir,
|
||||
num_workers=num_workers or os.cpu_count(),
|
||||
num_workers=num_workers or _get_default_num_workers(),
|
||||
fast_dev_run=fast_dev_run,
|
||||
num_downloaders=num_downloaders,
|
||||
num_uploaders=num_uploaders,
|
||||
|
@ -247,6 +258,7 @@ def optimize(
|
|||
num_uploaders: Optional[int] = None,
|
||||
reorder_files: bool = True,
|
||||
reader: Optional[BaseReader] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> None:
|
||||
"""This function converts a dataset into chunks possibly in a distributed way.
|
||||
|
||||
|
@ -266,6 +278,7 @@ def optimize(
|
|||
num_uploaders: The numbers of uploaders per worker.
|
||||
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
|
||||
Set this to ``False`` if the order in which samples are processed should be preserved.
|
||||
batch_size: Group the inputs into batches of batch_size length.
|
||||
|
||||
"""
|
||||
if not isinstance(inputs, Sequence):
|
||||
|
@ -302,10 +315,13 @@ def optimize(
|
|||
|
||||
input_dir = _resolve_dir(_get_input_dir(inputs))
|
||||
|
||||
if isinstance(batch_size, int) and batch_size > 1:
|
||||
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
|
||||
|
||||
data_processor = DataProcessor(
|
||||
input_dir=input_dir,
|
||||
output_dir=_output_dir,
|
||||
num_workers=num_workers or os.cpu_count(),
|
||||
num_workers=num_workers or _get_default_num_workers(),
|
||||
fast_dev_run=fast_dev_run,
|
||||
num_downloaders=num_downloaders,
|
||||
num_uploaders=num_uploaders,
|
||||
|
|
|
@ -1025,6 +1025,25 @@ def test_map_is_last(num_workers, expected, tmpdir):
|
|||
assert sorted(os.listdir(tmpdir)) == expected
|
||||
|
||||
|
||||
def map_batch_size_fn(indexes, output_dir):
|
||||
path = os.path.join(output_dir, str(indexes))
|
||||
with open(path, "w") as f:
|
||||
f.write("hello world")
|
||||
|
||||
|
||||
def test_map_batch_size(tmpdir):
|
||||
map(
|
||||
map_batch_size_fn,
|
||||
list(range(5)),
|
||||
output_dir=str(tmpdir),
|
||||
error_when_not_empty=False,
|
||||
num_workers=1,
|
||||
batch_size=2,
|
||||
)
|
||||
|
||||
assert sorted(os.listdir(tmpdir)) == ["[0, 1]", "[2, 3]", "[4]"]
|
||||
|
||||
|
||||
def no_op(index):
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in New Issue