diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py
index 3f44cdf8a8..19998a48a8 100644
--- a/src/lightning/data/streaming/data_processor.py
+++ b/src/lightning/data/streaming/data_processor.py
@@ -28,6 +28,7 @@ from lightning.data.streaming.constants import (
     _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
+from lightning.data.utilities.packing import _pack_greedily
 from lightning.fabric.accelerators.cuda import is_cuda_available
 from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.utilities.distributed import (
@@ -205,13 +206,11 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
         remove_queue.put([local_filepath])


-def _associated_items_to_workers(num_workers: int, user_items: List[Any]) -> Tuple[List[int], List[List[Any]]]:
-    # Associate the items to the workers based on number of nodes and node rank.
+def _map_items_to_workers_sequentially(num_workers: int, user_items: List[Any]) -> List[List[Any]]:
     num_nodes = _get_num_nodes()
     current_node_rank = _get_node_rank()
     node_size = len(user_items) // num_nodes
     workers_user_items = []
-    begins = []
     for node_rank in range(num_nodes):
         if node_rank != current_node_rank:
             continue
@@ -225,9 +224,44 @@ def _associated_items_to_workers(num_workers: int, user_items: List[Any]) -> Tup
             begin = worker_idx * worker_size
             end = len(node_user_items) if is_last else (worker_idx + 1) * worker_size
             workers_user_items.append(node_user_items[begin:end])
-            begins.append(begin)
-        return begins, workers_user_items
-    raise RuntimeError(f"The current_node_rank {current_node_rank} doesn't exist in {num_nodes}.")
+        return workers_user_items
+
+
+def _map_items_to_workers_weighted(
+    num_workers: int, user_items: List[Any], weights: Optional[List[int]] = None
+) -> List[List[Any]]:
+    # Associate the items to the workers based on number of nodes and node rank.
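+    # Items are packed greedily into num_nodes * num_workers bins so that every worker receives a
+    # similar total weight; only the bins belonging to this node's workers are returned.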
+    weights = [1] * len(user_items) if weights is None else weights
+    num_nodes = _get_num_nodes()
+    node_rank = _get_node_rank()
+    world_size = num_nodes * num_workers
+
+    worker_items, worker_weights = _pack_greedily(items=user_items, weights=weights, num_bins=world_size)
+    worker_ids_this_node = range(node_rank * num_workers, (node_rank + 1) * num_workers)
+
+    for worker_id, size in worker_weights.items():
+        if worker_id not in worker_ids_this_node:
+            continue
+        print(f"Worker {worker_id} gets {size / 1e6:.1f} MB ({len(worker_items[worker_id])} files)")
+
+    return [worker_items[worker_id] for worker_id in worker_ids_this_node]
+
+
+def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
+    """Computes the total size in bytes of all file paths for every data structure in the given list."""
+    item_sizes = []
+    for item in items:
+        flattened_item, spec = tree_flatten(item)
+
+        num_bytes = 0
+        for index, element in enumerate(flattened_item):
+            if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
+                file_bytes = os.path.getsize(element)
+                if file_bytes == 0:
+                    raise RuntimeError(f"The file {element} has 0 bytes!")
+                num_bytes += file_bytes
+        item_sizes.append(num_bytes)
+    return item_sizes


 class BaseWorker:
@@ -235,7 +269,6 @@ class BaseWorker:
         self,
         worker_index: int,
         num_workers: int,
-        start_index: int,
         node_rank: int,
         data_recipe: "DataRecipe",
         input_dir: Dir,
@@ -250,7 +283,6 @@ class BaseWorker:
         """The BaseWorker is responsible to process the user data."""
         self.worker_index = worker_index
         self.num_workers = num_workers
-        self.start_index = start_index
         self.node_rank = node_rank
         self.data_recipe = data_recipe
         self.input_dir = input_dir
@@ -692,6 +724,7 @@ class DataProcessor:
         delete_cached_files: bool = True,
         fast_dev_run: Optional[Union[bool, int]] = None,
         random_seed: Optional[int] = 42,
+        reorder_files: bool = True,
     ):
         """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to
         make training faster.
@@ -704,6 +737,8 @@ class DataProcessor:
             delete_cached_files: Whether to delete the cached files.
             fast_dev_run: Whether to run a quick dev run.
            random_seed: The random seed to be set before shuffling the data.
+            reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
+                Set this to ``False`` if the order in which samples are processed should be preserved.

         """
         self.input_dir = _resolve_dir(input_dir)
@@ -717,6 +752,7 @@ class DataProcessor:
         self.progress_queue: Optional[Queue] = None
         self.error_queue: Queue = Queue()
         self.stop_queues: List[Queue] = []
+        self.reorder_files = reorder_files

         if self.input_dir:
             # Ensure the input dir is the same across all nodes
@@ -746,8 +782,15 @@ class DataProcessor:
         if not isinstance(user_items, list):
             raise ValueError("The `prepare_structure` should return a list of item metadata.")

-        # Associate the items to the workers based on num_nodes and node_rank
-        begins, workers_user_items = _associated_items_to_workers(self.num_workers, user_items)
+        if self.reorder_files:
+            # TODO: Only do this on node 0, and broadcast the item sizes to the other nodes.
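+            # Weight each item by the combined size of the files it references, so that workers
+            # receive a comparable number of bytes to process rather than a comparable number of items.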
+            item_sizes = _get_item_filesizes(user_items, base_path=self.input_dir.path)
+            workers_user_items = _map_items_to_workers_weighted(
+                num_workers=self.num_workers, user_items=user_items, weights=item_sizes
+            )
+        else:
+            workers_user_items = _map_items_to_workers_sequentially(num_workers=self.num_workers, user_items=user_items)
+
         print(f"Setup finished in {round(time() - t0, 3)} seconds. Found {len(user_items)} items to process.")

         if self.fast_dev_run:
@@ -767,7 +810,7 @@

         signal.signal(signal.SIGINT, self._signal_handler)

-        self._create_process_workers(data_recipe, begins, workers_user_items)
+        self._create_process_workers(data_recipe, workers_user_items)

         print("Workers are ready ! Starting data processing...")

@@ -835,9 +878,7 @@
             w.join(0)
         raise RuntimeError(f"We found the following error {error}.")

-    def _create_process_workers(
-        self, data_recipe: DataRecipe, begins: List[int], workers_user_items: List[List[Any]]
-    ) -> None:
+    def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: List[List[Any]]) -> None:
         self.progress_queue = Queue()
         workers: List[DataWorkerProcess] = []
         stop_queues: List[Queue] = []
@@ -846,7 +887,6 @@
             worker = DataWorkerProcess(
                 worker_idx,
                 self.num_workers,
-                begins[worker_idx],
                 _get_node_rank(),
                 data_recipe,
                 self.input_dir,
diff --git a/src/lightning/data/utilities/packing.py b/src/lightning/data/utilities/packing.py
new file mode 100644
index 0000000000..309a32d726
--- /dev/null
+++ b/src/lightning/data/utilities/packing.py
@@ -0,0 +1,23 @@
+from collections import defaultdict
+from typing import Any, Dict, List, Tuple
+
+
+def _pack_greedily(items: List[Any], weights: List[int], num_bins: int) -> Tuple[Dict[int, List[Any]], Dict[int, int]]:
+    """Greedily pack items with the given weights into bins such that the total weight is distributed roughly
+    equally across all bins."""
+
+    if len(items) != len(weights):
+        raise ValueError(f"Items and weights must have the same length, got {len(items)} and {len(weights)}.")
+    if any(w <= 0 for w in weights):
+        raise ValueError("All weights must be positive.")
+
+    sorted_items_and_weights = sorted(zip(items, weights), key=lambda x: x[1], reverse=True)
+    bin_contents = defaultdict(list)
+    bin_weights = {i: 0 for i in range(num_bins)}
+
+    for item, weight in sorted_items_and_weights:
+        min_bin_id = min(bin_weights, key=(lambda x: bin_weights[x]), default=0)
+        bin_contents[min_bin_id].append(item)
+        bin_weights[min_bin_id] += weight
+
+    return bin_contents, bin_weights
diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py
index b33a215c1a..e5ca10e520 100644
--- a/tests/tests_data/streaming/test_data_processor.py
+++ b/tests/tests_data/streaming/test_data_processor.py
@@ -1,4 +1,5 @@
 import os
+import random
 import sys
 from typing import Any, List
 from unittest import mock
@@ -14,8 +15,10 @@ from lightning.data.streaming.data_processor import (
     DataChunkRecipe,
     DataProcessor,
     DataTransformRecipe,
-    _associated_items_to_workers,
     _download_data_target,
+    _get_item_filesizes,
+    _map_items_to_workers_sequentially,
+    _map_items_to_workers_weighted,
     _remove_target,
     _upload_fn,
     _wait_for_file_to_exist,
@@ -249,75 +252,98 @@ def test_cache_dir_cleanup(tmpdir, monkeypatch):
     assert os.listdir(cache_dir) == []


-def test_associated_items_to_workers(monkeypatch):
-    _, workers_user_items = _associated_items_to_workers(1, range(105))
-    assert workers_user_items == [range(0, 105)]
-
-    _, workers_user_items = _associated_items_to_workers(2, range(105))
-    assert workers_user_items == [range(0, 52), range(52, 105)]
-
-    _, workers_user_items = _associated_items_to_workers(3, range(105))
-    assert workers_user_items == [range(0, 35), range(35, 70), range(70, 105)]
-
-    _, workers_user_items = _associated_items_to_workers(4, range(105))
-    assert workers_user_items == [range(0, 26), range(26, 52), range(52, 78), range(78, 105)]
+def test_map_items_to_workers_weighted(monkeypatch):
+    workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
+    assert workers_user_items == [list(range(5))]
+    workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
+    assert workers_user_items == [[0, 2, 4], [1, 3]]
+    workers_user_items = _map_items_to_workers_weighted(3, list(range(5)))
+    assert workers_user_items == [[0, 3], [1, 4], [2]]
+    workers_user_items = _map_items_to_workers_weighted(4, list(range(5)))
+    assert workers_user_items == [[0, 4], [1], [2], [3]]

     monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
+    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
+    workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
+    assert workers_user_items == [[0, 2, 4]]
+    workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
+    assert workers_user_items == [[0, 4], [1]]

-    _, workers_user_items = _associated_items_to_workers(1, range(105))
-    assert workers_user_items == [range(0, 52)]
-
-    _, workers_user_items = _associated_items_to_workers(2, range(105))
-    assert workers_user_items == [range(0, 26), range(26, 52)]
-
-    _, workers_user_items = _associated_items_to_workers(3, range(105))
-    assert workers_user_items == [range(0, 17), range(17, 34), range(34, 52)]
-
-    _, workers_user_items = _associated_items_to_workers(4, range(105))
-    assert workers_user_items == [range(0, 13), range(13, 26), range(26, 39), range(39, 52)]
-
+    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
     monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
-
-    _, workers_user_items = _associated_items_to_workers(1, range(105))
-    assert workers_user_items == [range(52, 105)]
-
-    _, workers_user_items = _associated_items_to_workers(2, range(105))
-    assert workers_user_items == [range(52, 78), range(78, 105)]
-
-    _, workers_user_items = _associated_items_to_workers(3, range(105))
-    assert workers_user_items == [range(52, 69), range(69, 86), range(86, 105)]
-
-    _, workers_user_items = _associated_items_to_workers(4, range(105))
-    assert workers_user_items == [range(52, 65), range(65, 78), range(78, 91), range(91, 105)]
+    workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
+    assert workers_user_items == [[1, 3]]
+    workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
+    assert workers_user_items == [[2], [3]]

     monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
     monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
+    workers_user_items = _map_items_to_workers_weighted(1, list(range(32)))
+    assert workers_user_items == [[0, 4, 8, 12, 16, 20, 24, 28]]
+    workers_user_items = _map_items_to_workers_weighted(2, list(range(32)))
+    assert workers_user_items == [[0, 8, 16, 24], [1, 9, 17, 25]]
+    workers_user_items = _map_items_to_workers_weighted(3, list(range(32)))
+    assert workers_user_items == [[0, 12, 24], [1, 13, 25], [2, 14, 26]]
+    workers_user_items = _map_items_to_workers_weighted(4, list(range(32)))
+    assert workers_user_items == [[0, 16], [1, 17], [2, 18], [3, 19]]

-    _, workers_user_items = _associated_items_to_workers(1, range(105))
-    assert workers_user_items == [range(0, 26)]
-
-    _, workers_user_items = _associated_items_to_workers(2, range(105))
-    assert workers_user_items == [range(0, 13), range(13, 26)]
-
-    _, workers_user_items = _associated_items_to_workers(3, range(105))
-    assert workers_user_items == [range(0, 8), range(8, 16), range(16, 26)]
-
-    _, workers_user_items = _associated_items_to_workers(4, range(105))
-    assert workers_user_items == [range(0, 6), range(6, 12), range(12, 18), range(18, 26)]
-
+    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
     monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "3")
+    workers_user_items = _map_items_to_workers_weighted(1, list(range(32)))
+    assert workers_user_items == [[3, 7, 11, 15, 19, 23, 27, 31]]
+    workers_user_items = _map_items_to_workers_weighted(2, list(range(32)))
+    assert workers_user_items == [[6, 14, 22, 30], [7, 15, 23, 31]]
+    workers_user_items = _map_items_to_workers_weighted(3, list(range(32)))
+    assert workers_user_items == [[9, 21], [10, 22], [11, 23]]
+    workers_user_items = _map_items_to_workers_weighted(4, list(range(32)))
+    assert workers_user_items == [[12, 28], [13, 29], [14, 30], [15, 31]]

-    _, workers_user_items = _associated_items_to_workers(1, range(105))
-    assert workers_user_items == [range(78, 105)]

-    _, workers_user_items = _associated_items_to_workers(2, range(105))
-    assert workers_user_items == [range(78, 91), range(91, 105)]
+def test_map_items_to_workers_sequentially(monkeypatch):
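+    # Without the distributed env variables set, the items are split into contiguous chunks, one per worker.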
+    workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)))
+    assert workers_user_items == [list(range(5))]
+    workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
+    assert workers_user_items == [[0, 1], [2, 3, 4]]
+    workers_user_items = _map_items_to_workers_sequentially(3, list(range(5)))
+    assert workers_user_items == [[0], [1], [2, 3, 4]]
+    workers_user_items = _map_items_to_workers_sequentially(4, list(range(5)))
+    assert workers_user_items == [[0], [1], [2], [3, 4]]

-    _, workers_user_items = _associated_items_to_workers(3, range(105))
-    assert workers_user_items == [range(78, 87), range(87, 96), range(96, 105)]
+    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
+    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
+    workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)))
+    assert workers_user_items == [[0, 1]]
+    workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
+    assert workers_user_items == [[0], [1]]

-    _, workers_user_items = _associated_items_to_workers(4, range(105))
-    assert workers_user_items == [range(78, 84), range(84, 90), range(90, 96), range(96, 105)]
+    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
+    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
+    workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)))
+    assert workers_user_items == [[2, 3, 4]]
+    workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
+    assert workers_user_items == [[2], [3, 4]]
+
+    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
+    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
+    workers_user_items = _map_items_to_workers_sequentially(1, list(range(32)))
+    assert workers_user_items == [[0, 1, 2, 3, 4, 5, 6, 7]]
+    workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
+    assert workers_user_items == [[0, 1, 2, 3], [4, 5, 6, 7]]
+    workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
+    assert workers_user_items == [[0, 1], [2, 3], [4, 5, 6, 7]]
+    workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
+    assert workers_user_items == [[0, 1], [2, 3], [4, 5], [6, 7]]
+
+    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
+    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "3")
+    workers_user_items = _map_items_to_workers_sequentially(1, list(range(32)))
+    assert workers_user_items == [[24, 25, 26, 27, 28, 29, 30, 31]]
+    workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
+    assert workers_user_items == [[24, 25, 26, 27], [28, 29, 30, 31]]
+    workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
+    assert workers_user_items == [[24, 25], [26, 27], [28, 29, 30, 31]]
+    workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
+    assert workers_user_items == [[24, 25], [26, 27], [28, 29], [30, 31]]


 class CustomDataChunkRecipe(DataChunkRecipe):
@@ -519,7 +545,7 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
         "data_format": "jpeg",
         "compression": None,
         "num_chunks": 16,
-        "num_bytes_per_chunk": [2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2],
+        "num_bytes_per_chunk": [2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2],
     }

@@ -807,3 +833,40 @@ def test_lambda_transform_recipe_class(monkeypatch):

     data_recipe.prepare_item("", 1)
     assert called
+
+
+def _generate_file_with_size(file_path, num_bytes):
+    assert num_bytes % 8 == 0
+    content = bytearray(random.getrandbits(8) for _ in range(num_bytes))
+    with open(file_path, "wb") as file:
+        file.write(content)
+
+
+def test_get_item_filesizes(tmp_path):
+    _generate_file_with_size(tmp_path / "file1", 32)
+    _generate_file_with_size(tmp_path / "file2", 64)
+    _generate_file_with_size(tmp_path / "file3", 128)
+    _generate_file_with_size(tmp_path / "file4", 256)
+
+    items = [
+        # not a path
+        "not a path",
+        # single file path
+        str(tmp_path / "file1"),
+        # tuple: one file path
+        (1, 2, str(tmp_path / "file2")),
+        # list: two file paths
+        [str(tmp_path / "file2"), None, str(tmp_path / "file3")],
+        # list: one file path exists, one does not
+        [str(tmp_path / "other" / "other"), None, str(tmp_path / "file4")],
+        # dict: with file path
+        {"file": str(tmp_path / "file4"), "data": "not file"},
+    ]
+    num_bytes = _get_item_filesizes(items, base_path=str(tmp_path))
+    assert num_bytes == [0, 32, 64, 64 + 128, 256, 256]
+
+    with open(tmp_path / "empty_file", "w"):
+        pass
+    assert os.path.getsize(tmp_path / "empty_file") == 0
+    with pytest.raises(RuntimeError, match="has 0 bytes!"):
+        _get_item_filesizes([str(tmp_path / "empty_file")])
diff --git a/tests/tests_data/utilities/__init__.py b/tests/tests_data/utilities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/tests_data/utilities/test_packing.py b/tests/tests_data/utilities/test_packing.py
new file mode 100644
index 0000000000..878083ccf1
--- /dev/null
+++ b/tests/tests_data/utilities/test_packing.py
@@ -0,0 +1,44 @@
+import pytest
+from lightning.data.utilities.packing import _pack_greedily
+
+
+def test_pack_greedily():
+    with pytest.raises(ValueError, match="must have the same length"):
+        _pack_greedily(items=["A"], weights=[], num_bins=1)
+    with pytest.raises(ValueError, match="must have the same length"):
+        _pack_greedily(items=[], weights=[1], num_bins=1)
+    with pytest.raises(ValueError, match="must be positive"):
+        _pack_greedily(items=["A"], weights=[0], num_bins=1)
+    with pytest.raises(ValueError, match="must be positive"):
+        _pack_greedily(items=["A"], weights=[-1], num_bins=1)
+
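+    # Degenerate inputs: packing nothing succeeds even with zero bins, and unused bins keep a weight of 0.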
+    assert _pack_greedily(items=[], weights=[], num_bins=0) == ({}, {})
+    assert _pack_greedily(items=[], weights=[], num_bins=1) == ({}, {0: 0})
+
+    # one item, one bin
+    bin_contents, bin_weights = _pack_greedily(items=["A"], weights=[1], num_bins=1)
+    assert bin_contents == {0: ["A"]}
+    assert bin_weights == {0: 1}
+
+    # more bins than items
+    bin_contents, bin_weights = _pack_greedily(items=["A"], weights=[1], num_bins=3)
+    assert bin_contents == {0: ["A"]}
+    assert bin_weights == {0: 1, 1: 0, 2: 0}
+
+    # items with equal weight
+    bin_contents, bin_weights = _pack_greedily(items=["A", "B", "C", "D"], weights=[3, 3, 3, 3], num_bins=4)
+    assert bin_contents == {0: ["A"], 1: ["B"], 2: ["C"], 3: ["D"]}
+    assert bin_weights == {0: 3, 1: 3, 2: 3, 3: 3}
+
+    # pigeonhole principle: more items than bins
+    bin_contents, bin_weights = _pack_greedily(items=["A", "B", "C", "D"], weights=[1, 1, 1, 1], num_bins=3)
+    assert bin_contents == {0: ["A", "D"], 1: ["B"], 2: ["C"]}
+    assert bin_weights == {0: 2, 1: 1, 2: 1}
+
+    bin_contents, bin_weights = _pack_greedily(
+        items=["A", "B", "C", "D", "E", "F", "G", "H", "I"],
+        weights=[4, 1, 2, 5, 8, 7, 3, 6, 9],
+        num_bins=3,
+    )
+    assert bin_contents == {0: ["I", "A", "G"], 1: ["E", "D", "C"], 2: ["F", "H", "B"]}
+    assert bin_weights == {0: 16, 1: 15, 2: 14}
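+    # Greedy strategy: items are visited in decreasing weight order (I, E, F, H, D, A, G, C, B) and each one
+    # goes into the currently lightest bin, which yields the near-equal totals 16, 15 and 14 asserted above.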