Greedily select files for data processor workers based on size (#18907)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Parent: e79ac21415
Commit: 62771f3932
@ -28,6 +28,7 @@ from lightning.data.streaming.constants import (
    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50,
    _TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.utilities.packing import _pack_greedily
from lightning.fabric.accelerators.cuda import is_cuda_available
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.utilities.distributed import (

@ -205,13 +206,11 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
            remove_queue.put([local_filepath])


def _associated_items_to_workers(num_workers: int, user_items: List[Any]) -> Tuple[List[int], List[List[Any]]]:
    # Associate the items to the workers based on number of nodes and node rank.
def _map_items_to_workers_sequentially(num_workers: int, user_items: List[Any]) -> List[List[Any]]:
    num_nodes = _get_num_nodes()
    current_node_rank = _get_node_rank()
    node_size = len(user_items) // num_nodes
    workers_user_items = []
    begins = []
    for node_rank in range(num_nodes):
        if node_rank != current_node_rank:
            continue

@ -225,9 +224,44 @@ def _associated_items_to_workers(num_workers: int, user_items: List[Any]) -> Tup
            begin = worker_idx * worker_size
            end = len(node_user_items) if is_last else (worker_idx + 1) * worker_size
            workers_user_items.append(node_user_items[begin:end])
            begins.append(begin)
        return begins, workers_user_items
    raise RuntimeError(f"The current_node_rank {current_node_rank} doesn't exist in {num_nodes}.")
    return workers_user_items


def _map_items_to_workers_weighted(
    num_workers: int, user_items: List[Any], weights: Optional[List[int]] = None
) -> List[List[Any]]:
    # Associate the items to the workers based on number of nodes and node rank.
    weights = [1] * len(user_items) if weights is None else weights
    num_nodes = _get_num_nodes()
    node_rank = _get_node_rank()
    world_size = num_nodes * num_workers

    worker_items, worker_weights = _pack_greedily(items=user_items, weights=weights, num_bins=world_size)
    worker_ids_this_node = range(node_rank * num_workers, (node_rank + 1) * num_workers)

    for worker_id, size in worker_weights.items():
        if worker_id not in worker_ids_this_node:
            continue
        print(f"Worker {worker_id} gets {size / 1e6:.1f} MB ({len(worker_items[worker_id])} files)")

    return [worker_items[worker_id] for worker_id in worker_ids_this_node]


def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
    """Computes the total size in bytes of all file paths for every datastructure in the given list."""
    item_sizes = []
    for item in items:
        flattened_item, spec = tree_flatten(item)

        num_bytes = 0
        for index, element in enumerate(flattened_item):
            if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
                file_bytes = os.path.getsize(element)
                if file_bytes == 0:
                    raise RuntimeError(f"The file {element} has 0 bytes!")
                num_bytes += file_bytes
        item_sizes.append(num_bytes)
    return item_sizes


class BaseWorker:

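Not part of the diff: a minimal sketch contrasting the two mapping strategies above on a single node, assuming an installation that includes this commit. The item names and byte sizes are invented; the environment variables are the same ones the new tests below set via monkeypatch.

import os
from lightning.data.streaming.data_processor import (
    _map_items_to_workers_sequentially,
    _map_items_to_workers_weighted,
)

# One node, rank 0 (the same setup the new tests use).
os.environ["DATA_OPTIMIZER_NUM_NODES"] = "1"
os.environ["DATA_OPTIMIZER_NODE_RANK"] = "0"

items = ["a.bin", "b.bin", "c.bin", "d.bin", "e.bin"]  # hypothetical items
sizes = [100, 400, 200, 800, 300]                      # hypothetical byte sizes

# Sequential: contiguous slices, sizes ignored.
print(_map_items_to_workers_sequentially(2, items))
# -> [['a.bin', 'b.bin'], ['c.bin', 'd.bin', 'e.bin']]

# Weighted: greedy packing so both workers receive a similar number of bytes.
print(_map_items_to_workers_weighted(2, items, weights=sizes))
# -> [['d.bin', 'a.bin'], ['b.bin', 'e.bin', 'c.bin']]  (900 vs 900 bytes)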
@ -235,7 +269,6 @@ class BaseWorker:
        self,
        worker_index: int,
        num_workers: int,
        start_index: int,
        node_rank: int,
        data_recipe: "DataRecipe",
        input_dir: Dir,

@ -250,7 +283,6 @@ class BaseWorker:
        """The BaseWorker is responsible to process the user data."""
        self.worker_index = worker_index
        self.num_workers = num_workers
        self.start_index = start_index
        self.node_rank = node_rank
        self.data_recipe = data_recipe
        self.input_dir = input_dir

@ -692,6 +724,7 @@ class DataProcessor:
        delete_cached_files: bool = True,
        fast_dev_run: Optional[Union[bool, int]] = None,
        random_seed: Optional[int] = 42,
        reorder_files: bool = True,
    ):
        """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
        training faster.

@ -704,6 +737,8 @@ class DataProcessor:
            delete_cached_files: Whether to delete the cached files.
            fast_dev_run: Whether to run a quick dev run.
            random_seed: The random seed to be set before shuffling the data.
            reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
                Set this to ``False`` if the order in which samples are processed should be preserved.

        """
        self.input_dir = _resolve_dir(input_dir)

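Not part of the diff: a hedged sketch of how a caller might turn the new behavior off. Only the reorder_files and fast_dev_run arguments are taken from this diff; the input path is a placeholder and any other constructor arguments are assumed to keep their defaults.

from lightning.data.streaming.data_processor import DataProcessor

processor = DataProcessor(
    input_dir="/path/to/raw/data",  # placeholder path
    fast_dev_run=True,
    reorder_files=False,  # keep the original item order instead of size-balancing
)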
@ -717,6 +752,7 @@ class DataProcessor:
        self.progress_queue: Optional[Queue] = None
        self.error_queue: Queue = Queue()
        self.stop_queues: List[Queue] = []
        self.reorder_files = reorder_files

        if self.input_dir:
            # Ensure the input dir is the same across all nodes

@ -746,8 +782,15 @@ class DataProcessor:
        if not isinstance(user_items, list):
            raise ValueError("The `prepare_structure` should return a list of item metadata.")

        # Associate the items to the workers based on num_nodes and node_rank
        begins, workers_user_items = _associated_items_to_workers(self.num_workers, user_items)
        if self.reorder_files:
            # TODO: Only do this on node 0, and broadcast the item sizes to the other nodes.
            item_sizes = _get_item_filesizes(user_items, base_path=self.input_dir.path)
            workers_user_items = _map_items_to_workers_weighted(
                num_workers=self.num_workers, user_items=user_items, weights=item_sizes
            )
        else:
            workers_user_items = _map_items_to_workers_sequentially(num_workers=self.num_workers, user_items=user_items)

        print(f"Setup finished in {round(time() - t0, 3)} seconds. Found {len(user_items)} items to process.")

        if self.fast_dev_run:

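Not part of the diff: a self-contained sketch of what the reorder_files branch above does, measuring each item's size on disk and then balancing the items across workers by total bytes. File names and sizes are invented.

import os
import tempfile

from lightning.data.streaming.data_processor import (
    _get_item_filesizes,
    _map_items_to_workers_weighted,
)

tmp = tempfile.mkdtemp()
items = []
for name, size in [("a.bin", 1024), ("b.bin", 4096), ("c.bin", 2048), ("d.bin", 512)]:
    path = os.path.join(tmp, name)
    with open(path, "wb") as f:
        f.write(b"\0" * size)  # create a file of the requested size
    items.append(path)

item_sizes = _get_item_filesizes(items, base_path=tmp)  # [1024, 4096, 2048, 512]
workers_user_items = _map_items_to_workers_weighted(num_workers=2, user_items=items, weights=item_sizes)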
@ -767,7 +810,7 @@ class DataProcessor:

        signal.signal(signal.SIGINT, self._signal_handler)

        self._create_process_workers(data_recipe, begins, workers_user_items)
        self._create_process_workers(data_recipe, workers_user_items)

        print("Workers are ready ! Starting data processing...")

@ -835,9 +878,7 @@ class DataProcessor:
                w.join(0)
            raise RuntimeError(f"We found the following error {error}.")

    def _create_process_workers(
        self, data_recipe: DataRecipe, begins: List[int], workers_user_items: List[List[Any]]
    ) -> None:
    def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: List[List[Any]]) -> None:
        self.progress_queue = Queue()
        workers: List[DataWorkerProcess] = []
        stop_queues: List[Queue] = []

@ -846,7 +887,6 @@ class DataProcessor:
            worker = DataWorkerProcess(
                worker_idx,
                self.num_workers,
                begins[worker_idx],
                _get_node_rank(),
                data_recipe,
                self.input_dir,

New module lightning.data.utilities.packing (added by this commit):

@ -0,0 +1,23 @@
from collections import defaultdict
from typing import Any, Dict, List, Tuple


def _pack_greedily(items: List[Any], weights: List[int], num_bins: int) -> Tuple[Dict[int, List[Any]], Dict[int, int]]:
    """Greedily pack items with given weights into bins such that the total weight of each bin is roughly equally
    distributed among all bins."""

    if len(items) != len(weights):
        raise ValueError(f"Items and weights must have the same length, got {len(items)} and {len(weights)}.")
    if any(w <= 0 for w in weights):
        raise ValueError("All weights must be positive.")

    sorted_items_and_weights = sorted(zip(items, weights), key=lambda x: x[1], reverse=True)
    bin_contents = defaultdict(list)
    bin_weights = {i: 0 for i in range(num_bins)}

    for item, weight in sorted_items_and_weights:
        min_bin_id = min(bin_weights, key=(lambda x: bin_weights[x]), default=0)
        bin_contents[min_bin_id].append(item)
        bin_weights[min_bin_id] += weight

    return bin_contents, bin_weights

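Not part of the diff: a brief usage sketch of the new helper. The inputs and the expected packing mirror the test_pack_greedily case added at the end of this commit.

from lightning.data.utilities.packing import _pack_greedily

# Heaviest item first, always placed into the currently lightest bin.
bin_contents, bin_weights = _pack_greedily(
    items=["A", "B", "C", "D", "E", "F", "G", "H", "I"],
    weights=[4, 1, 2, 5, 8, 7, 3, 6, 9],
    num_bins=3,
)
print(dict(bin_contents))  # {0: ['I', 'A', 'G'], 1: ['E', 'D', 'C'], 2: ['F', 'H', 'B']}
print(bin_weights)         # {0: 16, 1: 15, 2: 14}

The remaining hunks update the data_processor tests for the renamed helpers and add a dedicated test module for _pack_greedily.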
@ -1,4 +1,5 @@
import os
import random
import sys
from typing import Any, List
from unittest import mock

@ -14,8 +15,10 @@ from lightning.data.streaming.data_processor import (
    DataChunkRecipe,
    DataProcessor,
    DataTransformRecipe,
    _associated_items_to_workers,
    _download_data_target,
    _get_item_filesizes,
    _map_items_to_workers_sequentially,
    _map_items_to_workers_weighted,
    _remove_target,
    _upload_fn,
    _wait_for_file_to_exist,

@ -249,75 +252,98 @@ def test_cache_dir_cleanup(tmpdir, monkeypatch):
    assert os.listdir(cache_dir) == []


def test_associated_items_to_workers(monkeypatch):
    _, workers_user_items = _associated_items_to_workers(1, range(105))
    assert workers_user_items == [range(0, 105)]

    _, workers_user_items = _associated_items_to_workers(2, range(105))
    assert workers_user_items == [range(0, 52), range(52, 105)]

    _, workers_user_items = _associated_items_to_workers(3, range(105))
    assert workers_user_items == [range(0, 35), range(35, 70), range(70, 105)]

    _, workers_user_items = _associated_items_to_workers(4, range(105))
    assert workers_user_items == [range(0, 26), range(26, 52), range(52, 78), range(78, 105)]
def test_map_items_to_workers_weighted(monkeypatch):
    workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
    assert workers_user_items == [list(range(5))]
    workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
    assert workers_user_items == [[0, 2, 4], [1, 3]]
    workers_user_items = _map_items_to_workers_weighted(3, list(range(5)))
    assert workers_user_items == [[0, 3], [1, 4], [2]]
    workers_user_items = _map_items_to_workers_weighted(4, list(range(5)))
    assert workers_user_items == [[0, 4], [1], [2], [3]]

    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
    workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
    assert workers_user_items == [[0, 2, 4]]
    workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
    assert workers_user_items == [[0, 4], [1]]

    _, workers_user_items = _associated_items_to_workers(1, range(105))
    assert workers_user_items == [range(0, 52)]

    _, workers_user_items = _associated_items_to_workers(2, range(105))
    assert workers_user_items == [range(0, 26), range(26, 52)]

    _, workers_user_items = _associated_items_to_workers(3, range(105))
    assert workers_user_items == [range(0, 17), range(17, 34), range(34, 52)]

    _, workers_user_items = _associated_items_to_workers(4, range(105))
    assert workers_user_items == [range(0, 13), range(13, 26), range(26, 39), range(39, 52)]

    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")

    _, workers_user_items = _associated_items_to_workers(1, range(105))
    assert workers_user_items == [range(52, 105)]

    _, workers_user_items = _associated_items_to_workers(2, range(105))
    assert workers_user_items == [range(52, 78), range(78, 105)]

    _, workers_user_items = _associated_items_to_workers(3, range(105))
    assert workers_user_items == [range(52, 69), range(69, 86), range(86, 105)]

    _, workers_user_items = _associated_items_to_workers(4, range(105))
    assert workers_user_items == [range(52, 65), range(65, 78), range(78, 91), range(91, 105)]
    workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
    assert workers_user_items == [[1, 3]]
    workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
    assert workers_user_items == [[2], [3]]

    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
    workers_user_items = _map_items_to_workers_weighted(1, list(range(32)))
    assert workers_user_items == [[0, 4, 8, 12, 16, 20, 24, 28]]
    workers_user_items = _map_items_to_workers_weighted(2, list(range(32)))
    assert workers_user_items == [[0, 8, 16, 24], [1, 9, 17, 25]]
    workers_user_items = _map_items_to_workers_weighted(3, list(range(32)))
    assert workers_user_items == [[0, 12, 24], [1, 13, 25], [2, 14, 26]]
    workers_user_items = _map_items_to_workers_weighted(4, list(range(32)))
    assert workers_user_items == [[0, 16], [1, 17], [2, 18], [3, 19]]

    _, workers_user_items = _associated_items_to_workers(1, range(105))
    assert workers_user_items == [range(0, 26)]

    _, workers_user_items = _associated_items_to_workers(2, range(105))
    assert workers_user_items == [range(0, 13), range(13, 26)]

    _, workers_user_items = _associated_items_to_workers(3, range(105))
    assert workers_user_items == [range(0, 8), range(8, 16), range(16, 26)]

    _, workers_user_items = _associated_items_to_workers(4, range(105))
    assert workers_user_items == [range(0, 6), range(6, 12), range(12, 18), range(18, 26)]

    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "3")
    workers_user_items = _map_items_to_workers_weighted(1, list(range(32)))
    assert workers_user_items == [[3, 7, 11, 15, 19, 23, 27, 31]]
    workers_user_items = _map_items_to_workers_weighted(2, list(range(32)))
    assert workers_user_items == [[6, 14, 22, 30], [7, 15, 23, 31]]
    workers_user_items = _map_items_to_workers_weighted(3, list(range(32)))
    assert workers_user_items == [[9, 21], [10, 22], [11, 23]]
    workers_user_items = _map_items_to_workers_weighted(4, list(range(32)))
    assert workers_user_items == [[12, 28], [13, 29], [14, 30], [15, 31]]

    _, workers_user_items = _associated_items_to_workers(1, range(105))
    assert workers_user_items == [range(78, 105)]

    _, workers_user_items = _associated_items_to_workers(2, range(105))
    assert workers_user_items == [range(78, 91), range(91, 105)]
def test_map_items_to_workers_sequentially(monkeypatch):
    workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)))
    assert workers_user_items == [list(range(5))]
    workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
    assert workers_user_items == [[0, 1], [2, 3, 4]]
    workers_user_items = _map_items_to_workers_sequentially(3, list(range(5)))
    assert workers_user_items == [[0], [1], [2, 3, 4]]
    workers_user_items = _map_items_to_workers_sequentially(4, list(range(5)))
    assert workers_user_items == [[0], [1], [2], [3, 4]]

    _, workers_user_items = _associated_items_to_workers(3, range(105))
    assert workers_user_items == [range(78, 87), range(87, 96), range(96, 105)]
    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
    workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)))
    assert workers_user_items == [[0, 1]]
    workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
    assert workers_user_items == [[0], [1]]

    _, workers_user_items = _associated_items_to_workers(4, range(105))
    assert workers_user_items == [range(78, 84), range(84, 90), range(90, 96), range(96, 105)]
    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
    workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)))
    assert workers_user_items == [[2, 3, 4]]
    workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
    assert workers_user_items == [[2], [3, 4]]

    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
    workers_user_items = _map_items_to_workers_sequentially(1, list(range(32)))
    assert workers_user_items == [[0, 1, 2, 3, 4, 5, 6, 7]]
    workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
    assert workers_user_items == [[0, 1, 2, 3], [4, 5, 6, 7]]
    workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
    assert workers_user_items == [[0, 1], [2, 3], [4, 5, 6, 7]]
    workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
    assert workers_user_items == [[0, 1], [2, 3], [4, 5], [6, 7]]

    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "3")
    workers_user_items = _map_items_to_workers_sequentially(1, list(range(32)))
    assert workers_user_items == [[24, 25, 26, 27, 28, 29, 30, 31]]
    workers_user_items = _map_items_to_workers_sequentially(2, list(range(32)))
    assert workers_user_items == [[24, 25, 26, 27], [28, 29, 30, 31]]
    workers_user_items = _map_items_to_workers_sequentially(3, list(range(32)))
    assert workers_user_items == [[24, 25], [26, 27], [28, 29, 30, 31]]
    workers_user_items = _map_items_to_workers_sequentially(4, list(range(32)))
    assert workers_user_items == [[24, 25], [26, 27], [28, 29], [30, 31]]


class CustomDataChunkRecipe(DataChunkRecipe):

@ -519,7 +545,7 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
        "data_format": "jpeg",
        "compression": None,
        "num_chunks": 16,
        "num_bytes_per_chunk": [2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2],
        "num_bytes_per_chunk": [2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2],
    }


@ -807,3 +833,40 @@ def test_lambda_transform_recipe_class(monkeypatch):

    data_recipe.prepare_item("", 1)
    assert called


def _generate_file_with_size(file_path, num_bytes):
    assert num_bytes % 8 == 0
    content = bytearray(random.getrandbits(8) for _ in range(num_bytes))
    with open(file_path, "wb") as file:
        file.write(content)


def test_get_item_filesizes(tmp_path):
    _generate_file_with_size(tmp_path / "file1", 32)
    _generate_file_with_size(tmp_path / "file2", 64)
    _generate_file_with_size(tmp_path / "file3", 128)
    _generate_file_with_size(tmp_path / "file4", 256)

    items = [
        # not a path
        "not a path",
        # single file path
        str(tmp_path / "file1"),
        # tuple: one file path
        (1, 2, str(tmp_path / "file2")),
        # list: two file paths
        [str(tmp_path / "file2"), None, str(tmp_path / "file3")],
        # list: one file path exists, one does not
        [str(tmp_path / "other" / "other"), None, str(tmp_path / "file4")],
        # dict: with file path
        {"file": str(tmp_path / "file4"), "data": "not file"},
    ]
    num_bytes = _get_item_filesizes(items, base_path=str(tmp_path))
    assert num_bytes == [0, 32, 64, 64 + 128, 256, 256]

    with open(tmp_path / "empty_file", "w"):
        pass
    assert os.path.getsize(tmp_path / "empty_file") == 0
    with pytest.raises(RuntimeError, match="has 0 bytes!"):
        _get_item_filesizes([str(tmp_path / "empty_file")])

New test module for _pack_greedily (added by this commit):

@ -0,0 +1,44 @@
import pytest
from lightning.data.utilities.packing import _pack_greedily


def test_pack_greedily():
    with pytest.raises(ValueError, match="must have the same length"):
        _pack_greedily(items=["A"], weights=[], num_bins=1)
    with pytest.raises(ValueError, match="must have the same length"):
        _pack_greedily(items=[], weights=[1], num_bins=1)
    with pytest.raises(ValueError, match="must be positive"):
        _pack_greedily(items=["A"], weights=[0], num_bins=1)
    with pytest.raises(ValueError, match="must be positive"):
        _pack_greedily(items=["A"], weights=[-1], num_bins=1)

    assert _pack_greedily(items=[], weights=[], num_bins=0) == ({}, {})
    assert _pack_greedily(items=[], weights=[], num_bins=1) == ({}, {0: 0})

    # one item, one bin
    bin_contents, bin_weights = _pack_greedily(items=["A"], weights=[1], num_bins=1)
    assert bin_contents == {0: ["A"]}
    assert bin_weights == {0: 1}

    # more bins than items
    bin_contents, bin_weights = _pack_greedily(items=["A"], weights=[1], num_bins=3)
    assert bin_contents == {0: ["A"]}
    assert bin_weights == {0: 1, 1: 0, 2: 0}

    # items with equal weight
    bin_contents, bin_weights = _pack_greedily(items=["A", "B", "C", "D"], weights=[3, 3, 3, 3], num_bins=4)
    assert bin_contents == {0: ["A"], 1: ["B"], 2: ["C"], 3: ["D"]}
    assert bin_weights == {0: 3, 1: 3, 2: 3, 3: 3}

    # pigeonhole principle: more items than bins
    bin_contents, bin_weights = _pack_greedily(items=["A", "B", "C", "D"], weights=[1, 1, 1, 1], num_bins=3)
    assert bin_contents == {0: ["A", "D"], 1: ["B"], 2: ["C"]}
    assert bin_weights == {0: 2, 1: 1, 2: 1}

    bin_contents, bin_weights = _pack_greedily(
        items=["A", "B", "C", "D", "E", "F", "G", "H", "I"],
        weights=[4, 1, 2, 5, 8, 7, 3, 6, 9],
        num_bins=3,
    )
    assert bin_contents == {0: ["I", "A", "G"], 1: ["E", "D", "C"], 2: ["F", "H", "B"]}
    assert bin_weights == {0: 16, 1: 15, 2: 14}