Fix oversized items not fitting into a chunk (#18938)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 4af77c6bf1
commit 0e7a3b0b5f
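Summary of the change: previously, `BinaryWriter._create_chunk` asserted that the accumulated bytes never exceeded the configured `chunk_bytes`, so a single item larger than the target chunk size failed the write with an AssertionError. With this commit, such an item is written into its own oversized chunk and a `UserWarning` is emitted instead. A minimal sketch of the new behavior, mirroring `test_create_oversized_chunk_single_item` below (the `Cache` import path and the temporary-directory handling are assumptions, not part of this diff):

import tempfile

import numpy as np
from lightning.data.streaming.cache import Cache  # import path assumed

cache_dir = tempfile.mkdtemp()
cache = Cache(cache_dir, chunk_bytes=700)  # target chunk size of 700 bytes
# Storing an item much larger than chunk_bytes previously tripped an assertion in the
# writer; with this fix the item lands in its own oversized chunk and only warns.
cache[0] = np.random.randint(0, 10, size=(10_000,), dtype=np.uint8)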
@@ -13,6 +13,7 @@
 import json
 import os
+import warnings
 from dataclasses import dataclass
 from time import sleep
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -24,6 +25,7 @@ from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
 from lightning.data.streaming.compression import _COMPRESSORS, Compressor
 from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
 from lightning.data.streaming.serializers import _SERIALIZERS, Serializer
+from lightning.data.utilities.format import _human_readable_bytes
 
 if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps
@@ -42,7 +44,7 @@ _FORMAT_TO_RATIO = {


 def _convert_bytes_to_int(bytes_str: str) -> int:
-    """Convert human readable byte format to an integer."""
+    """Convert human-readable byte format to an integer."""
     for suffix in _FORMAT_TO_RATIO:
         bytes_str = bytes_str.lower().strip()
         if bytes_str.lower().endswith(suffix):
@@ -50,12 +52,8 @@ def _convert_bytes_to_int(bytes_str: str) -> int:
                 return int(float(bytes_str[0 : -len(suffix)]) * _FORMAT_TO_RATIO[suffix])
             except ValueError:
                 raise ValueError(
-                    "".join(
-                        [
-                            f"Unsupported value/suffix {bytes_str}. Supported suffix are ",
-                            f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.',
-                        ]
-                    )
+                    f"Unsupported value/suffix {bytes_str}. Supported suffix are "
+                    f'{["b"] + list(_FORMAT_TO_RATIO.keys())}.'
                 )
     raise ValueError(f"The supported units are {_FORMAT_TO_RATIO.keys()}")
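`_convert_bytes_to_int` parses human-readable size strings into a byte count using the suffixes registered in `_FORMAT_TO_RATIO`, which is defined above this hunk and not shown here; it presumably lets callers pass sizes such as "64MB" for `chunk_bytes`. A rough usage sketch, assuming the mapping contains decimal suffixes like "kb" and "mb" and that the function lives in `lightning.data.streaming.writer` (both assumptions):

from lightning.data.streaming.writer import _convert_bytes_to_int  # module path assumed

_convert_bytes_to_int("64mb")   # 64 * the ratio registered for "mb" (assumed suffix)
_convert_bytes_to_int("500kb")  # 500 * the ratio registered for "kb" (assumed suffix)
_convert_bytes_to_int("oops")   # no known suffix -> ValueError listing the supported units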
@@ -212,21 +210,20 @@ class BinaryWriter:
     def _create_chunk(self, filename: str, on_done: bool = False) -> bytes:
         """Create a binary chunk from all the binarized items."""
+        items = []
+
         if on_done:
             indices = sorted(self._serialized_items.keys())
             for i in range(len(indices) - 1):
                 assert indices[i] == indices[i + 1] - 1, indices
-            min_index = indices[0]
-            max_index = indices[-1] + 1
-            num_items = np.uint32(max_index - min_index)
             items = [self._serialized_items.pop(index) for index in indices]
         else:
             assert self._max_index is not None, (self._max_index, self._min_index)
             assert self._min_index is not None, (self._max_index, self._min_index)
-            num_items = np.uint32(self._max_index - self._min_index)
-            items = [self._serialized_items.pop(index) for index in range(self._min_index, self._max_index)]
-            min_index = self._min_index
-            max_index = self._max_index
+            if self._max_index == self._min_index:
+                # A single item is larger than the target chunk size; allow the chunk to be bigger than the target size
+                items.append(self._serialized_items.pop(self._max_index))
+            items.extend(self._serialized_items.pop(index) for index in range(self._min_index, self._max_index))
 
         if len(items) == 0:
             raise RuntimeError(
@@ -234,17 +231,21 @@ class BinaryWriter:
                 f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}."
             )
 
+        num_items = np.uint32(len(items))
         sizes = list(map(len, items))
         offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
         offsets += len(num_items.tobytes()) + len(offsets.tobytes())
         sample_data = b"".join([item.data for item in items])
         data = num_items.tobytes() + offsets.tobytes() + sample_data
         offsets = offsets.tolist()
 
         current_chunk_bytes = sum([item.bytes for item in items])
 
-        if self._chunk_bytes:
-            assert current_chunk_bytes <= self._chunk_bytes
+        if self._chunk_bytes and current_chunk_bytes > self._chunk_bytes:
+            warnings.warn(
+                f"An item was larger than the target chunk size ({_human_readable_bytes(self._chunk_bytes)})."
+                f" The current chunk will be {_human_readable_bytes(current_chunk_bytes)} in size.",
+                UserWarning,
+            )
 
         if self._chunk_size:
             assert num_items.item() <= self._chunk_size
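The unchanged context lines above also document the chunk layout: a uint32 item count, then uint32 offsets that are already shifted by the header size, then the concatenated item bytes. A minimal decoding sketch based only on that layout (the library ships its own item loaders for reading; this is purely illustrative):

import numpy as np

def read_chunk_items(data: bytes) -> list:
    # Header: uint32 item count, then (count + 1) uint32 offsets into `data`.
    num_items = int(np.frombuffer(data, dtype=np.uint32, count=1)[0])
    offsets = np.frombuffer(data, dtype=np.uint32, count=num_items + 1, offset=4)
    # offsets[i] .. offsets[i + 1] spans the i-th serialized item.
    return [data[offsets[i] : offsets[i + 1]] for i in range(num_items)]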
@@ -308,6 +309,7 @@ class BinaryWriter:
         return filepath
 
     def _should_write(self) -> bool:
+        # TODO: Misleading method name, it modifies `self._min_index` and `self._max_index`!
         if not self._serialized_items:
             return False
         indexes = list(self._serialized_items.keys())
@@ -0,0 +1,6 @@
+def _human_readable_bytes(num_bytes: float) -> str:
+    for unit in ("B", "KB", "MB", "GB", "TB"):
+        if abs(num_bytes) < 1000.0:
+            return f"{num_bytes:3.1f} {unit}"
+        num_bytes /= 1000.0
+    return f"{num_bytes:.1f} PB"
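The new helper formats sizes with decimal (1000-based) units, consistent with the test added at the bottom of this diff. A few illustrative values, computed from the function above rather than taken from the diff:

from lightning.data.utilities.format import _human_readable_bytes

_human_readable_bytes(700)        # '700.0 B'
_human_readable_bytes(9_000)      # '9.0 KB'
_human_readable_bytes(1_200_000)  # '1.2 MB'  (decimal units: 1 MB == 1_000_000 bytes here)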
@@ -10,7 +10,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
 import sys
 from functools import partial
@@ -27,6 +25,7 @@ from lightning.data.streaming.item_loader import TokensLoader
 from lightning.fabric import Fabric
 from lightning.pytorch.demos.boring_classes import RandomDataset
 from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import DataLoader, Dataset
 
 _PIL_AVAILABLE = RequirementCache("PIL")
@@ -242,3 +243,35 @@ def test_streaming_dataset(tmpdir, monkeypatch):
 
     dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
     assert len(dataloader) == 408
+
+
+def test_create_oversized_chunk_single_item(tmp_path):
+    cache = Cache(str(tmp_path), chunk_bytes=700)
+    with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):
+        cache[0] = np.random.randint(0, 10, size=(10000,), dtype=np.uint8)
+
+
+def test_create_undersized_and_oversized_chunk(tmp_path):
+    cache = Cache(str(tmp_path), chunk_bytes=9000)  # target: 9KB chunks
+    with no_warning_call(UserWarning):
+        cache[0] = np.random.randint(0, 10, size=(500,), dtype=np.uint8)  # will result in undersized chunk
+        cache[1] = np.random.randint(0, 10, size=(10000,), dtype=np.uint8)  # will result in oversized chunk
+    with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):
+        cache[2] = np.random.randint(0, 10, size=(150,), dtype=np.uint8)
+    with no_warning_call(UserWarning):
+        cache[3] = np.random.randint(0, 10, size=(200,), dtype=np.uint8)
+
+    cache.done()
+    cache.merge()
+
+    assert len(os.listdir(tmp_path)) == 4  # 3 chunks + 1 index file
+    with open(tmp_path / "index.json") as file:
+        index = json.load(file)
+
+    chunks = index["chunks"]
+    assert chunks[0]["chunk_size"] == 1
+    assert chunks[0]["filename"] == "chunk-0-0.bin"
+    assert chunks[1]["chunk_size"] == 1
+    assert chunks[1]["filename"] == "chunk-0-1.bin"
+    assert chunks[2]["chunk_size"] == 2
+    assert chunks[2]["filename"] == "chunk-0-2.bin"
@@ -0,0 +1,18 @@
+from lightning.data.utilities.format import _human_readable_bytes
+
+
+def test_human_readable_bytes():
+    assert _human_readable_bytes(0) == "0.0 B"
+    assert _human_readable_bytes(1) == "1.0 B"
+    assert _human_readable_bytes(999) == "999.0 B"
+    assert _human_readable_bytes(int(1e3)) == "1.0 KB"
+    assert _human_readable_bytes(int(1e3 + 1e2)) == "1.1 KB"
+    assert _human_readable_bytes(int(1e6)) == "1.0 MB"
+    assert _human_readable_bytes(int(1e6 + 2e5)) == "1.2 MB"
+    assert _human_readable_bytes(int(1e9)) == "1.0 GB"
+    assert _human_readable_bytes(int(1e9 + 3e8)) == "1.3 GB"
+    assert _human_readable_bytes(int(1e12)) == "1.0 TB"
+    assert _human_readable_bytes(int(1e12 + 4e11)) == "1.4 TB"
+    assert _human_readable_bytes(int(1e15)) == "1.0 PB"
+    assert _human_readable_bytes(int(1e15 + 5e14)) == "1.5 PB"
+    assert _human_readable_bytes(int(1e18)) == "1000.0 PB"