diff --git a/src/lightning/data/streaming/writer.py b/src/lightning/data/streaming/writer.py
index 59234392b0..126bd907e8 100644
--- a/src/lightning/data/streaming/writer.py
+++ b/src/lightning/data/streaming/writer.py
@@ -13,6 +13,7 @@
 
 import json
 import os
+import warnings
 from dataclasses import dataclass
 from time import sleep
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -24,6 +25,7 @@ from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
 from lightning.data.streaming.compression import _COMPRESSORS, Compressor
 from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
 from lightning.data.streaming.serializers import _SERIALIZERS, Serializer
+from lightning.data.utilities.format import _human_readable_bytes
 
 if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps
@@ -42,7 +44,7 @@ _FORMAT_TO_RATIO = {
 
 
 def _convert_bytes_to_int(bytes_str: str) -> int:
-    """Convert human readable byte format to an integer."""
+    """Convert human-readable byte format to an integer."""
     for suffix in _FORMAT_TO_RATIO:
         bytes_str = bytes_str.lower().strip()
         if bytes_str.lower().endswith(suffix):
@@ -50,12 +52,8 @@ def _convert_bytes_to_int(bytes_str: str) -> int:
                 return int(float(bytes_str[0 : -len(suffix)]) * _FORMAT_TO_RATIO[suffix])
             except ValueError:
                 raise ValueError(
-                    "".join(
-                        [
-                            f"Unsupported value/suffix {bytes_str}. Supported suffix are ",
-                            f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.',
-                        ]
-                    )
+                    f"Unsupported value/suffix {bytes_str}. Supported suffix are "
+                    f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.'
                 )
 
     raise ValueError(f"The supported units are {_FORMAT_TO_RATIO.keys()}")
@@ -212,21 +210,20 @@ class BinaryWriter:
 
     def _create_chunk(self, filename: str, on_done: bool = False) -> bytes:
         """Create a binary chunk from all the binarized items."""
+        items = []
+
         if on_done:
             indices = sorted(self._serialized_items.keys())
             for i in range(len(indices) - 1):
                 assert indices[i] == indices[i + 1] - 1, indices
-            min_index = indices[0]
-            max_index = indices[-1] + 1
-            num_items = np.uint32(max_index - min_index)
             items = [self._serialized_items.pop(index) for index in indices]
         else:
             assert self._max_index is not None, (self._max_index, self._min_index)
             assert self._min_index is not None, (self._max_index, self._min_index)
-            num_items = np.uint32(self._max_index - self._min_index)
-            items = [self._serialized_items.pop(index) for index in range(self._min_index, self._max_index)]
-            min_index = self._min_index
-            max_index = self._max_index
+            if self._max_index == self._min_index:
+                # A single item is larger than the target chunk size; allow the chunk to be bigger than the target size
+                items.append(self._serialized_items.pop(self._max_index))
+            items.extend(self._serialized_items.pop(index) for index in range(self._min_index, self._max_index))
 
         if len(items) == 0:
             raise RuntimeError(
@@ -234,17 +231,21 @@ class BinaryWriter:
                 f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}."
             )
 
+        num_items = np.uint32(len(items))
         sizes = list(map(len, items))
         offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
         offsets += len(num_items.tobytes()) + len(offsets.tobytes())
         sample_data = b"".join([item.data for item in items])
         data = num_items.tobytes() + offsets.tobytes() + sample_data
-        offsets = offsets.tolist()
 
         current_chunk_bytes = sum([item.bytes for item in items])
 
-        if self._chunk_bytes:
-            assert current_chunk_bytes <= self._chunk_bytes
+        if self._chunk_bytes and current_chunk_bytes > self._chunk_bytes:
+            warnings.warn(
+                f"An item was larger than the target chunk size ({_human_readable_bytes(self._chunk_bytes)})."
+                f" The current chunk will be {_human_readable_bytes(current_chunk_bytes)} in size.",
+                UserWarning,
+            )
 
         if self._chunk_size:
             assert num_items.item() <= self._chunk_size
@@ -308,6 +309,7 @@ class BinaryWriter:
         return filepath
 
     def _should_write(self) -> bool:
+        # TODO: Misleading method name, it modifies `self._min_index` and `self._max_index`!
         if not self._serialized_items:
             return False
         indexes = list(self._serialized_items.keys())
diff --git a/src/lightning/data/utilities/__init__.py b/src/lightning/data/utilities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/lightning/data/utilities/format.py b/src/lightning/data/utilities/format.py
new file mode 100644
index 0000000000..4492661b2e
--- /dev/null
+++ b/src/lightning/data/utilities/format.py
@@ -0,0 +1,6 @@
+def _human_readable_bytes(num_bytes: float) -> str:
+    for unit in ("B", "KB", "MB", "GB", "TB"):
+        if abs(num_bytes) < 1000.0:
+            return f"{num_bytes:3.1f} {unit}"
+        num_bytes /= 1000.0
+    return f"{num_bytes:.1f} PB"
diff --git a/tests/tests_data/streaming/test_cache.py b/tests/tests_data/streaming/test_cache.py
index 057872d553..41317735be 100644
--- a/tests/tests_data/streaming/test_cache.py
+++ b/tests/tests_data/streaming/test_cache.py
@@ -10,7 +10,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import json
 import os
 import sys
 from functools import partial
@@ -27,6 +27,7 @@ from lightning.data.streaming.item_loader import TokensLoader
 from lightning.fabric import Fabric
 from lightning.pytorch.demos.boring_classes import RandomDataset
 from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import DataLoader, Dataset
 
 _PIL_AVAILABLE = RequirementCache("PIL")
@@ -242,3 +243,35 @@ def test_streaming_dataset(tmpdir, monkeypatch):
 
     dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
     assert len(dataloader) == 408
+
+
+def test_create_oversized_chunk_single_item(tmp_path):
+    cache = Cache(str(tmp_path), chunk_bytes=700)
+    with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):
+        cache[0] = np.random.randint(0, 10, size=(10000,), dtype=np.uint8)
+
+
+def test_create_undersized_and_oversized_chunk(tmp_path):
+    cache = Cache(str(tmp_path), chunk_bytes=9000)  # target: 9KB chunks
+    with no_warning_call(UserWarning):
+        cache[0] = np.random.randint(0, 10, size=(500,), dtype=np.uint8)  # will result in undersized chunk
+        cache[1] = np.random.randint(0, 10, size=(10000,), dtype=np.uint8)  # will result in oversized chunk
+    with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):
+        cache[2] = np.random.randint(0, 10, size=(150,), dtype=np.uint8)
+    with no_warning_call(UserWarning):
+        cache[3] = np.random.randint(0, 10, size=(200,), dtype=np.uint8)
+
+    cache.done()
+    cache.merge()
+
+    assert len(os.listdir(tmp_path)) == 4  # 3 chunks + 1 index file
+    with open(tmp_path / "index.json") as file:
+        index = json.load(file)
+
+    chunks = index["chunks"]
+    assert chunks[0]["chunk_size"] == 1
+    assert chunks[0]["filename"] == "chunk-0-0.bin"
+    assert chunks[1]["chunk_size"] == 1
+    assert chunks[1]["filename"] == "chunk-0-1.bin"
+    assert chunks[2]["chunk_size"] == 2
+    assert chunks[2]["filename"] == "chunk-0-2.bin"
diff --git a/tests/tests_data/utilities/test_format.py b/tests/tests_data/utilities/test_format.py
new file mode 100644
index 0000000000..38bb7d447e
--- /dev/null
+++ b/tests/tests_data/utilities/test_format.py
@@ -0,0 +1,18 @@
+from lightning.data.utilities.format import _human_readable_bytes
+
+
+def test_human_readable_bytes():
+    assert _human_readable_bytes(0) == "0.0 B"
+    assert _human_readable_bytes(1) == "1.0 B"
+    assert _human_readable_bytes(999) == "999.0 B"
+    assert _human_readable_bytes(int(1e3)) == "1.0 KB"
+    assert _human_readable_bytes(int(1e3 + 1e2)) == "1.1 KB"
+    assert _human_readable_bytes(int(1e6)) == "1.0 MB"
+    assert _human_readable_bytes(int(1e6 + 2e5)) == "1.2 MB"
+    assert _human_readable_bytes(int(1e9)) == "1.0 GB"
+    assert _human_readable_bytes(int(1e9 + 3e8)) == "1.3 GB"
+    assert _human_readable_bytes(int(1e12)) == "1.0 TB"
+    assert _human_readable_bytes(int(1e12 + 4e11)) == "1.4 TB"
+    assert _human_readable_bytes(int(1e15)) == "1.0 PB"
+    assert _human_readable_bytes(int(1e15 + 5e14)) == "1.5 PB"
+    assert _human_readable_bytes(int(1e18)) == "1000.0 PB"