Fix oversized items not fitting into a chunk (#18938)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Adrian Wälchli 2023-11-05 11:28:18 +01:00 committed by GitHub
parent 4af77c6bf1
commit 0e7a3b0b5f
5 changed files with 77 additions and 18 deletions

View File

@@ -13,6 +13,7 @@
import json
import os
import warnings
from dataclasses import dataclass
from time import sleep
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -24,6 +25,7 @@ from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
from lightning.data.streaming.compression import _COMPRESSORS, Compressor
from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.serializers import _SERIALIZERS, Serializer
from lightning.data.utilities.format import _human_readable_bytes
if _TORCH_GREATER_EQUAL_2_1_0:
from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps
@@ -42,7 +44,7 @@ _FORMAT_TO_RATIO = {
def _convert_bytes_to_int(bytes_str: str) -> int:
"""Convert human readable byte format to an integer."""
"""Convert human-readable byte format to an integer."""
for suffix in _FORMAT_TO_RATIO:
bytes_str = bytes_str.lower().strip()
if bytes_str.lower().endswith(suffix):
@@ -50,12 +52,8 @@ def _convert_bytes_to_int(bytes_str: str) -> int:
return int(float(bytes_str[0 : -len(suffix)]) * _FORMAT_TO_RATIO[suffix])
except ValueError:
raise ValueError(
"".join(
[
f"Unsupported value/suffix {bytes_str}. Supported suffix are ",
f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.',
]
)
f"Unsupported value/suffix {bytes_str}. Supported suffix are "
f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.'
)
raise ValueError(f"The supported units are {_FORMAT_TO_RATIO.keys()}")
@@ -212,21 +210,20 @@ class BinaryWriter:
def _create_chunk(self, filename: str, on_done: bool = False) -> bytes:
"""Create a binary chunk from all the binarized items."""
items = []
if on_done:
indices = sorted(self._serialized_items.keys())
for i in range(len(indices) - 1):
assert indices[i] == indices[i + 1] - 1, indices
min_index = indices[0]
max_index = indices[-1] + 1
num_items = np.uint32(max_index - min_index)
items = [self._serialized_items.pop(index) for index in indices]
else:
assert self._max_index is not None, (self._max_index, self._min_index)
assert self._min_index is not None, (self._max_index, self._min_index)
num_items = np.uint32(self._max_index - self._min_index)
items = [self._serialized_items.pop(index) for index in range(self._min_index, self._max_index)]
min_index = self._min_index
max_index = self._max_index
if self._max_index == self._min_index:
# A single item is larger than the target chunk size; allow the chunk to be bigger than the target size
items.append(self._serialized_items.pop(self._max_index))
items.extend(self._serialized_items.pop(index) for index in range(self._min_index, self._max_index))
if len(items) == 0:
raise RuntimeError(
@@ -234,17 +231,21 @@ class BinaryWriter:
f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}."
)
num_items = np.uint32(len(items))
sizes = list(map(len, items))
offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
offsets += len(num_items.tobytes()) + len(offsets.tobytes())
sample_data = b"".join([item.data for item in items])
data = num_items.tobytes() + offsets.tobytes() + sample_data
offsets = offsets.tolist()
current_chunk_bytes = sum([item.bytes for item in items])
if self._chunk_bytes:
assert current_chunk_bytes <= self._chunk_bytes
if self._chunk_bytes and current_chunk_bytes > self._chunk_bytes:
warnings.warn(
f"An item was larger than the target chunk size ({_human_readable_bytes(self._chunk_bytes)})."
f" The current chunk will be {_human_readable_bytes(current_chunk_bytes)} in size.",
UserWarning,
)
if self._chunk_size:
assert num_items.item() <= self._chunk_size
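To make the byte layout assembled in `_create_chunk` easier to follow, here is a minimal standalone sketch using plain `bytes` objects in place of the writer's serialized items (names and sample data below are illustrative only):

```python
import numpy as np

# Chunk layout as assembled above:
# [num_items: uint32][offsets: (num_items + 1) x uint32][concatenated sample bytes]
items = [b"first-sample", b"second-sample-longer", b"third"]

num_items = np.uint32(len(items))
sizes = list(map(len, items))
offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
offsets += len(num_items.tobytes()) + len(offsets.tobytes())  # offsets point past the header
data = num_items.tobytes() + offsets.tobytes() + b"".join(items)

# Item i can be recovered as data[offsets[i]:offsets[i + 1]]
assert data[offsets[1] : offsets[2]] == b"second-sample-longer"
```

With the change above, a single item larger than `chunk_bytes` still produces a valid (just oversized) chunk in this format, and the writer emits a `UserWarning` instead of failing an assertion.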
@@ -308,6 +309,7 @@ class BinaryWriter:
return filepath
def _should_write(self) -> bool:
# TODO: Misleading method name, it modifies `self._min_index` and `self._max_index`!
if not self._serialized_items:
return False
indexes = list(self._serialized_items.keys())

View File

View File

@@ -0,0 +1,6 @@
def _human_readable_bytes(num_bytes: float) -> str:
for unit in ("B", "KB", "MB", "GB", "TB"):
if abs(num_bytes) < 1000.0:
return f"{num_bytes:3.1f} {unit}"
num_bytes /= 1000.0
return f"{num_bytes:.1f} PB"

View File

@@ -10,7 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
from functools import partial
@@ -27,6 +27,7 @@ from lightning.data.streaming.item_loader import TokensLoader
from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import DataLoader, Dataset
_PIL_AVAILABLE = RequirementCache("PIL")
@@ -242,3 +243,35 @@ def test_streaming_dataset(tmpdir, monkeypatch):
dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
assert len(dataloader) == 408
def test_create_oversized_chunk_single_item(tmp_path):
cache = Cache(str(tmp_path), chunk_bytes=700)
with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):
cache[0] = np.random.randint(0, 10, size=(10000,), dtype=np.uint8)
def test_create_undersized_and_oversized_chunk(tmp_path):
cache = Cache(str(tmp_path), chunk_bytes=9000) # target: 9KB chunks
with no_warning_call(UserWarning):
cache[0] = np.random.randint(0, 10, size=(500,), dtype=np.uint8) # will result in undersized chunk
cache[1] = np.random.randint(0, 10, size=(10000,), dtype=np.uint8) # will result in oversized chunk
with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):
cache[2] = np.random.randint(0, 10, size=(150,), dtype=np.uint8)
with no_warning_call(UserWarning):
cache[3] = np.random.randint(0, 10, size=(200,), dtype=np.uint8)
cache.done()
cache.merge()
assert len(os.listdir(tmp_path)) == 4 # 3 chunks + 1 index file
with open(tmp_path / "index.json") as file:
index = json.load(file)
chunks = index["chunks"]
assert chunks[0]["chunk_size"] == 1
assert chunks[0]["filename"] == "chunk-0-0.bin"
assert chunks[1]["chunk_size"] == 1
assert chunks[1]["filename"] == "chunk-0-1.bin"
assert chunks[2]["chunk_size"] == 2
assert chunks[2]["filename"] == "chunk-0-2.bin"

View File

@@ -0,0 +1,18 @@
from lightning.data.utilities.format import _human_readable_bytes
def test_human_readable_bytes():
assert _human_readable_bytes(0) == "0.0 B"
assert _human_readable_bytes(1) == "1.0 B"
assert _human_readable_bytes(999) == "999.0 B"
assert _human_readable_bytes(int(1e3)) == "1.0 KB"
assert _human_readable_bytes(int(1e3 + 1e2)) == "1.1 KB"
assert _human_readable_bytes(int(1e6)) == "1.0 MB"
assert _human_readable_bytes(int(1e6 + 2e5)) == "1.2 MB"
assert _human_readable_bytes(int(1e9)) == "1.0 GB"
assert _human_readable_bytes(int(1e9 + 3e8)) == "1.3 GB"
assert _human_readable_bytes(int(1e12)) == "1.0 TB"
assert _human_readable_bytes(int(1e12 + 4e11)) == "1.4 TB"
assert _human_readable_bytes(int(1e15)) == "1.0 PB"
assert _human_readable_bytes(int(1e15 + 5e14)) == "1.5 PB"
assert _human_readable_bytes(int(1e18)) == "1000.0 PB"