Greedily select files for data processor workers based on size (#18907)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Adrian Wälchli 2023-11-07 01:33:50 +01:00 committed by GitHub
parent e79ac21415
commit 62771f3932
5 changed files with 243 additions and 73 deletions

View File

@@ -28,6 +28,7 @@ from lightning.data.streaming.constants import (
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.utilities.packing import _pack_greedily
from lightning.fabric.accelerators.cuda import is_cuda_available
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.utilities.distributed import (
@@ -205,13 +206,11 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
remove_queue.put([local_filepath])
def _associated_items_to_workers(num_workers: int, user_items: List[Any]) -> Tuple[List[int], List[List[Any]]]:
# Associate the items to the workers based on number of nodes and node rank.
def _map_items_to_workers_sequentially(num_workers: int, user_items: List[Any]) -> List[List[Any]]:
num_nodes = _get_num_nodes()
current_node_rank = _get_node_rank()
node_size = len(user_items) // num_nodes
workers_user_items = []
begins = []
for node_rank in range(num_nodes):
if node_rank != current_node_rank:
continue
@@ -225,9 +224,44 @@ def _associated_items_to_workers(num_workers: int, user_items: List[Any]) -> Tup
begin = worker_idx * worker_size
end = len(node_user_items) if is_last else (worker_idx + 1) * worker_size
workers_user_items.append(node_user_items[begin:end])
begins.append(begin)
return begins, workers_user_items
raise RuntimeError(f"The current_node_rank {current_node_rank} doesn't exist in {num_nodes}.")
return workers_user_items
def _map_items_to_workers_weighted(
num_workers: int, user_items: List[Any], weights: Optional[List[int]] = None
) -> List[List[Any]]:
# Associate the items to the workers based on number of nodes and node rank.
weights = [1] * len(user_items) if weights is None else weights
num_nodes = _get_num_nodes()
node_rank = _get_node_rank()
world_size = num_nodes * num_workers
worker_items, worker_weights = _pack_greedily(items=user_items, weights=weights, num_bins=world_size)
worker_ids_this_node = range(node_rank * num_workers, (node_rank + 1) * num_workers)
for worker_id, size in worker_weights.items():
if worker_id not in worker_ids_this_node:
continue
print(f"Worker {worker_id} gets {size / 1e6:.1f} MB ({len(worker_items[worker_id])} files)")
return [worker_items[worker_id] for worker_id in worker_ids_this_node]
def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
"""Computes the total size in bytes of all file paths for every datastructure in the given list."""
item_sizes = []
for item in items:
flattened_item, spec = tree_flatten(item)
num_bytes = 0
for index, element in enumerate(flattened_item):
if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
file_bytes = os.path.getsize(element)
if file_bytes == 0:
raise RuntimeError(f"The file {element} has 0 bytes!")
num_bytes += file_bytes
item_sizes.append(num_bytes)
return item_sizes
class BaseWorker:
@@ -235,7 +269,6 @@ class BaseWorker:
self,
worker_index: int,
num_workers: int,
start_index: int,
node_rank: int,
data_recipe: "DataRecipe",
input_dir: Dir,
@@ -250,7 +283,6 @@
"""The BaseWorker is responsible to process the user data."""
self.worker_index = worker_index
self.num_workers = num_workers
self.start_index = start_index
self.node_rank = node_rank
self.data_recipe = data_recipe
self.input_dir = input_dir
@@ -692,6 +724,7 @@ class DataProcessor:
delete_cached_files: bool = True,
fast_dev_run: Optional[Union[bool, int]] = None,
random_seed: Optional[int] = 42,
reorder_files: bool = True,
):
"""The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
training faster.
@@ -704,6 +737,8 @@
delete_cached_files: Whether to delete the cached files.
fast_dev_run: Whether to run a quick dev run.
random_seed: The random seed to be set before shuffling the data.
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
Set this to ``False`` if the order in which samples are processed should be preserved.
"""
self.input_dir = _resolve_dir(input_dir)
@@ -717,6 +752,7 @@
self.progress_queue: Optional[Queue] = None
self.error_queue: Queue = Queue()
self.stop_queues: List[Queue] = []
self.reorder_files = reorder_files
if self.input_dir:
# Ensure the input dir is the same across all nodes
@@ -746,8 +782,15 @@
if not isinstance(user_items, list):
raise ValueError("The `prepare_structure` should return a list of item metadata.")
# Associate the items to the workers based on num_nodes and node_rank
begins, workers_user_items = _associated_items_to_workers(self.num_workers, user_items)
if self.reorder_files:
# TODO: Only do this on node 0, and broadcast the item sizes to the other nodes.
item_sizes = _get_item_filesizes(user_items, base_path=self.input_dir.path)
workers_user_items = _map_items_to_workers_weighted(
num_workers=self.num_workers, user_items=user_items, weights=item_sizes
)
else:
workers_user_items = _map_items_to_workers_sequentially(num_workers=self.num_workers, user_items=user_items)
print(f"Setup finished in {round(time() - t0, 3)} seconds. Found {len(user_items)} items to process.")
if self.fast_dev_run:
@@ -767,7 +810,7 @@
signal.signal(signal.SIGINT, self._signal_handler)
self._create_process_workers(data_recipe, begins, workers_user_items)
self._create_process_workers(data_recipe, workers_user_items)
print("Workers are ready ! Starting data processing...")
@@ -835,9 +878,7 @@
w.join(0)
raise RuntimeError(f"We found the following error {error}.")
def _create_process_workers(
self, data_recipe: DataRecipe, begins: List[int], workers_user_items: List[List[Any]]
) -> None:
def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: List[List[Any]]) -> None:
self.progress_queue = Queue()
workers: List[DataWorkerProcess] = []
stop_queues: List[Queue] = []
@@ -846,7 +887,6 @@
worker = DataWorkerProcess(
worker_idx,
self.num_workers,
begins[worker_idx],
_get_node_rank(),
data_recipe,
self.input_dir,

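For orientation, here is a minimal standalone sketch (not part of the commit) of the idea behind the new weighted mapping: in the diff above, _get_item_filesizes flattens every item with tree_flatten and sums os.path.getsize over the path strings that exist under the input directory; those sizes become weights, items are packed greedily into num_nodes * num_workers bins, and each node keeps only the bins belonging to its own workers. All names and numbers below are illustrative, not the library's API.

from collections import defaultdict
from typing import Any, Dict, List, Tuple


def pack_greedily(items: List[Any], weights: List[int], num_bins: int) -> Tuple[Dict[int, List[Any]], Dict[int, int]]:
    # Largest weight first; each item goes into the currently lightest bin.
    bin_contents: Dict[int, List[Any]] = defaultdict(list)
    bin_weights = {i: 0 for i in range(num_bins)}
    for item, weight in sorted(zip(items, weights), key=lambda x: x[1], reverse=True):
        lightest = min(bin_weights, key=bin_weights.get)
        bin_contents[lightest].append(item)
        bin_weights[lightest] += weight
    return bin_contents, bin_weights


def map_items_to_workers_weighted(
    num_workers: int, num_nodes: int, node_rank: int, items: List[str], weights: List[int]
) -> List[List[str]]:
    # One bin per worker across the whole cluster; this node keeps its own block of bins.
    world_size = num_nodes * num_workers
    contents, _ = pack_greedily(items, weights, num_bins=world_size)
    worker_ids_this_node = range(node_rank * num_workers, (node_rank + 1) * num_workers)
    return [contents[worker_id] for worker_id in worker_ids_this_node]


# 5 files with very different sizes (in bytes, e.g. from os.path.getsize), 1 node, 2 workers:
files = ["a.jpg", "b.jpg", "c.jpg", "d.jpg", "e.jpg"]
sizes = [100, 80, 10, 10, 10]
print(map_items_to_workers_weighted(num_workers=2, num_nodes=1, node_rank=0, items=files, weights=sizes))
# [['a.jpg', 'e.jpg'], ['b.jpg', 'c.jpg', 'd.jpg']] -> bin weights 110 vs 100,
# instead of 180 vs 30 with a plain sequential split of the same list.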
View File

@@ -0,0 +1,23 @@
from collections import defaultdict
from typing import Any, Dict, List, Tuple
def _pack_greedily(items: List[Any], weights: List[int], num_bins: int) -> Tuple[Dict[int, List[Any]], Dict[int, int]]:
"""Greedily pack items with given weights into bins such that the total weight of each bin is roughly equally
distributed among all bins."""
if len(items) != len(weights):
raise ValueError(f"Items and weights must have the same length, got {len(items)} and {len(weights)}.")
if any(w <= 0 for w in weights):
raise ValueError("All weights must be positive.")
sorted_items_and_weights = sorted(zip(items, weights), key=lambda x: x[1], reverse=True)
bin_contents = defaultdict(list)
bin_weights = {i: 0 for i in range(num_bins)}
for item, weight in sorted_items_and_weights:
min_bin_id = min(bin_weights, key=(lambda x: bin_weights[x]), default=0)
bin_contents[min_bin_id].append(item)
bin_weights[min_bin_id] += weight
return bin_contents, bin_weights

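As a quick usage sketch of the new helper (assuming an installation that already contains this lightning.data.utilities.packing module), _pack_greedily balances the total weight per bin rather than the item count; the values below mirror the test added later in this diff.

from lightning.data.utilities.packing import _pack_greedily

# Nine items with uneven weights (think: file sizes) packed into 3 bins.
bin_contents, bin_weights = _pack_greedily(
    items=["A", "B", "C", "D", "E", "F", "G", "H", "I"],
    weights=[4, 1, 2, 5, 8, 7, 3, 6, 9],
    num_bins=3,
)
print(dict(bin_contents))  # {0: ['I', 'A', 'G'], 1: ['E', 'D', 'C'], 2: ['F', 'H', 'B']}
print(bin_weights)         # {0: 16, 1: 15, 2: 14} -> bin totals differ by at most 2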
View File

@@ -1,4 +1,5 @@
import os
import random
import sys
from typing import Any, List
from unittest import mock
@@ -14,8 +15,10 @@ from lightning.data.streaming.data_processor import (
DataChunkRecipe,
DataProcessor,
DataTransformRecipe,
_associated_items_to_workers,
_download_data_target,
_get_item_filesizes,
_map_items_to_workers_sequentially,
_map_items_to_workers_weighted,
_remove_target,
_upload_fn,
_wait_for_file_to_exist,
@@ -249,75 +252,98 @@ def test_cache_dir_cleanup(tmpdir, monkeypatch):
assert os.listdir(cache_dir) == []
def test_associated_items_to_workers(monkeypatch):
_, workers_user_items = _associated_items_to_workers(1, range(105))
assert workers_user_items == [range(0, 105)]
_, workers_user_items = _associated_items_to_workers(2, range(105))
assert workers_user_items == [range(0, 52), range(52, 105)]
_, workers_user_items = _associated_items_to_workers(3, range(105))
assert workers_user_items == [range(0, 35), range(35, 70), range(70, 105)]
_, workers_user_items = _associated_items_to_workers(4, range(105))
assert workers_user_items == [range(0, 26), range(26, 52), range(52, 78), range(78, 105)]
def test_map_items_to_workers_weighted(monkeypatch):
workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
assert workers_user_items == [list(range(5))]
workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
assert workers_user_items == [[0, 2, 4], [1, 3]]
workers_user_items = _map_items_to_workers_weighted(3, list(range(5)))
assert workers_user_items == [[0, 3], [1, 4], [2]]
workers_user_items = _map_items_to_workers_weighted(4, list(range(5)))
assert workers_user_items == [[0, 4], [1], [2], [3]]
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
assert workers_user_items == [[0, 2, 4]]
workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
assert workers_user_items == [[0, 4], [1]]
_, workers_user_items = _associated_items_to_workers(1, range(105))
assert workers_user_items == [range(0, 52)]
_, workers_user_items = _associated_items_to_workers(2, range(105))
assert workers_user_items == [range(0, 26), range(26, 52)]
_, workers_user_items = _associated_items_to_workers(3, range(105))
assert workers_user_items == [range(0, 17), range(17, 34), range(34, 52)]
_, workers_user_items = _associated_items_to_workers(4, range(105))
assert workers_user_items == [range(0, 13), range(13, 26), range(26, 39), range(39, 52)]
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
_, workers_user_items = _associated_items_to_workers(1, range(105))
assert workers_user_items == [range(52, 105)]
_, workers_user_items = _associated_items_to_workers(2, range(105))
assert workers_user_items == [range(52, 78), range(78, 105)]
_, workers_user_items = _associated_items_to_workers(3, range(105))
assert workers_user_items == [range(52, 69), range(69, 86), range(86, 105)]
_, workers_user_items = _associated_items_to_workers(4, range(105))
assert workers_user_items == [range(52, 65), range(65, 78), range(78, 91), range(91, 105)]
workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
assert workers_user_items == [[1, 3]]
workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
assert workers_user_items == [[2], [3]]
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
workers_user_items = _map_items_to_workers_weighted(1, list(range(32)))
assert workers_user_items == [[0, 4, 8, 12, 16, 20, 24, 28]]
workers_user_items = _map_items_to_workers_weighted(2, list(range(32)))
assert workers_user_items == [[0, 8, 16, 24], [1, 9, 17, 25]]
workers_user_items = _map_items_to_workers_weighted(3, list(range(32)))
assert workers_user_items == [[0, 12, 24], [1, 13, 25], [2, 14, 26]]
workers_user_items = _map_items_to_workers_weighted(4, list(range(32)))
assert workers_user_items == [[0, 16], [1, 17], [2, 18], [3, 19]]
_, workers_user_items = _associated_items_to_workers(1, range(105))
assert workers_user_items == [range(0, 26)]
_, workers_user_items = _associated_items_to_workers(2, range(105))
assert workers_user_items == [range(0, 13), range(13, 26)]
_, workers_user_items = _associated_items_to_workers(3, range(105))
assert workers_user_items == [range(0, 8), range(8, 16), range(16, 26)]
_, workers_user_items = _associated_items_to_workers(4, range(105))
assert workers_user_items == [range(0, 6), range(6, 12), range(12, 18), range(18, 26)]
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "3")
workers_user_items = _map_items_to_workers_weighted(1, list(range(32)))
assert workers_user_items == [[3, 7, 11, 15, 19, 23, 27, 31]]
workers_user_items = _map_items_to_workers_weighted(2, list(range(32)))
assert workers_user_items == [[6, 14, 22, 30], [7, 15, 23, 31]]
workers_user_items = _map_items_to_workers_weighted(3, list(range(32)))
assert workers_user_items == [[9, 21], [10, 22], [11, 23]]
workers_user_items = _map_items_to_workers_weighted(4, list(range(32)))
assert workers_user_items == [[12, 28], [13, 29], [14, 30], [15, 31]]
_, workers_user_items = _associated_items_to_workers(1, range(105))
assert workers_user_items == [range(78, 105)]
_, workers_user_items = _associated_items_to_workers(2, range(105))
assert workers_user_items == [range(78, 91), range(91, 105)]
def test_map_items_to_workers_sequentially(monkeypatch):
workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)))
assert workers_user_items == [list(range(5))]
workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
assert workers_user_items == [[0, 1], [2, 3, 4]]
workers_user_items = _map_items_to_workers_sequentially(3, list(range(5)))
assert workers_user_items == [[0], [1], [2, 3, 4]]
workers_user_items = _map_items_to_workers_sequentially(4, list(range(5)))
assert workers_user_items == [[0], [1], [2], [3, 4]]
_, workers_user_items = _associated_items_to_workers(3, range(105))
assert workers_user_items == [range(78, 87), range(87, 96), range(96, 105)]
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)))
assert workers_user_items == [[0, 1]]
workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
assert workers_user_items == [[0], [1]]
_, workers_user_items = _associated_items_to_workers(4, range(105))
assert workers_user_items == [range(78, 84), range(84, 90), range(90, 96), range(96, 105)]
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)))
assert workers_user_items == [[2, 3, 4]]
workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
assert workers_user_items == [[2], [3, 4]]
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
workers_user_items = _map_items_to_workers_sequentially(1, list(range(32)))
assert workers_user_items == [[0, 1, 2, 3, 4, 5, 6, 7]]
workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
assert workers_user_items == [[0, 1, 2, 3], [4, 5, 6, 7]]
workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
assert workers_user_items == [[0, 1], [2, 3], [4, 5, 6, 7]]
workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
assert workers_user_items == [[0, 1], [2, 3], [4, 5], [6, 7]]
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "3")
workers_user_items = _map_items_to_workers_sequentially(1, list(range(32)))
assert workers_user_items == [[24, 25, 26, 27, 28, 29, 30, 31]]
workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
assert workers_user_items == [[24, 25, 26, 27], [28, 29, 30, 31]]
workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
assert workers_user_items == [[24, 25], [26, 27], [28, 29, 30, 31]]
workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
assert workers_user_items == [[24, 25], [26, 27], [28, 29], [30, 31]]
class CustomDataChunkRecipe(DataChunkRecipe):
@@ -519,7 +545,7 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
"data_format": "jpeg",
"compression": None,
"num_chunks": 16,
"num_bytes_per_chunk": [2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2],
"num_bytes_per_chunk": [2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2],
}
@@ -807,3 +833,40 @@ def test_lambda_transform_recipe_class(monkeypatch):
data_recipe.prepare_item("", 1)
assert called
def _generate_file_with_size(file_path, num_bytes):
assert num_bytes % 8 == 0
content = bytearray(random.getrandbits(8) for _ in range(num_bytes))
with open(file_path, "wb") as file:
file.write(content)
def test_get_item_filesizes(tmp_path):
_generate_file_with_size(tmp_path / "file1", 32)
_generate_file_with_size(tmp_path / "file2", 64)
_generate_file_with_size(tmp_path / "file3", 128)
_generate_file_with_size(tmp_path / "file4", 256)
items = [
# not a path
"not a path",
# single file path
str(tmp_path / "file1"),
# tuple: one file path
(1, 2, str(tmp_path / "file2")),
# list: two file paths
[str(tmp_path / "file2"), None, str(tmp_path / "file3")],
# list: one file path exists, one does not
[str(tmp_path / "other" / "other"), None, str(tmp_path / "file4")],
# dict: with file path
{"file": str(tmp_path / "file4"), "data": "not file"},
]
num_bytes = _get_item_filesizes(items, base_path=str(tmp_path))
assert num_bytes == [0, 32, 64, 64 + 128, 256, 256]
with open(tmp_path / "empty_file", "w"):
pass
assert os.path.getsize(tmp_path / "empty_file") == 0
with pytest.raises(RuntimeError, match="has 0 bytes!"):
_get_item_filesizes([str(tmp_path / "empty_file")])

View File

View File

@@ -0,0 +1,44 @@
import pytest
from lightning.data.utilities.packing import _pack_greedily
def test_pack_greedily():
with pytest.raises(ValueError, match="must have the same length"):
_pack_greedily(items=["A"], weights=[], num_bins=1)
with pytest.raises(ValueError, match="must have the same length"):
_pack_greedily(items=[], weights=[1], num_bins=1)
with pytest.raises(ValueError, match="must be positive"):
_pack_greedily(items=["A"], weights=[0], num_bins=1)
with pytest.raises(ValueError, match="must be positive"):
_pack_greedily(items=["A"], weights=[-1], num_bins=1)
assert _pack_greedily(items=[], weights=[], num_bins=0) == ({}, {})
assert _pack_greedily(items=[], weights=[], num_bins=1) == ({}, {0: 0})
# one item, one bin
bin_contents, bin_weights = _pack_greedily(items=["A"], weights=[1], num_bins=1)
assert bin_contents == {0: ["A"]}
assert bin_weights == {0: 1}
# more bins than items
bin_contents, bin_weights = _pack_greedily(items=["A"], weights=[1], num_bins=3)
assert bin_contents == {0: ["A"]}
assert bin_weights == {0: 1, 1: 0, 2: 0}
# items with equal weight
bin_contents, bin_weights = _pack_greedily(items=["A", "B", "C", "D"], weights=[3, 3, 3, 3], num_bins=4)
assert bin_contents == {0: ["A"], 1: ["B"], 2: ["C"], 3: ["D"]}
assert bin_weights == {0: 3, 1: 3, 2: 3, 3: 3}
# pigeonhole principle: more items than bins
bin_contents, bin_weights = _pack_greedily(items=["A", "B", "C", "D"], weights=[1, 1, 1, 1], num_bins=3)
assert bin_contents == {0: ["A", "D"], 1: ["B"], 2: ["C"]}
assert bin_weights == {0: 2, 1: 1, 2: 1}
bin_contents, bin_weights = _pack_greedily(
items=["A", "B", "C", "D", "E", "F", "G", "H", "I"],
weights=[4, 1, 2, 5, 8, 7, 3, 6, 9],
num_bins=3,
)
assert bin_contents == {0: ["I", "A", "G"], 1: ["E", "D", "C"], 2: ["F", "H", "B"]}
assert bin_weights == {0: 16, 1: 15, 2: 14}
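For readers verifying the last expected result by hand: the helper takes items in decreasing weight order and always drops the next one into the currently lightest bin, with ties going to the lowest bin index. That gives I(9) -> bin 0, E(8) -> bin 1, F(7) -> bin 2, H(6) -> bin 2 (13), D(5) -> bin 1 (13), A(4) -> bin 0 (13), G(3) -> bin 0 (16), C(2) -> bin 1 (15), B(1) -> bin 2 (14), i.e. bins ["I", "A", "G"], ["E", "D", "C"], ["F", "H", "B"] with weights 16, 15, 14, exactly as asserted above.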