Resolve Item Loader bugs (#19017)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
thomas chaton 2023-11-16 18:06:58 -05:00 committed by GitHub
parent 3d448ac48d
commit 6e517bd55b
9 changed files with 319 additions and 33 deletions

View File

@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import hashlib
import os
from typing import Any, Dict, List, Optional, Union
@ -84,12 +83,16 @@ class StreamingDataset(IterableDataset):
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
cache_dir = copy.deepcopy(self.input_dir)
if cache_path:
cache_dir.path = cache_path
cache = Cache(input_dir=cache_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers)
# TODO: Move this to lightning-cloud
if "this_" not in self.input_dir.path:
cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
if cache_path is not None:
self.input_dir.path = cache_path
cache = Cache(
input_dir=self.input_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers
)
cache._reader._try_load_config()
if not cache.filled:
@ -136,6 +139,7 @@ class StreamingDataset(IterableDataset):
self.current_indexes = []
self.chunk_index = 0
self.index = 0
self.has_triggered_download = False
return self
@ -167,6 +171,8 @@ class StreamingDataset(IterableDataset):
self.current_indexes = self.shuffler(current_indexes)
self.chunk_index += 1
last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1
# Get the first index
index = self.current_indexes.pop(0)
@ -175,7 +181,9 @@ class StreamingDataset(IterableDataset):
ChunkedIndex(
index=index,
chunk_index=self.worker_chunks[self.chunk_index - 1],
# We provide the chunk indexes only with the first item
chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
last_index=last_index,
)
)
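
Note: a minimal sketch (not the library's code) of the indexing protocol the changes above establish, assuming only the `ChunkedIndex` dataclass extended in this commit: the full list of worker chunks travels with the first index so the reader can start pre-downloading, and `last_index=True` marks the worker's final sample. The `worker_chunks`/`worker_intervals` arguments are illustrative placeholders.

from typing import Iterator, List, Tuple

from lightning.data.streaming.config import ChunkedIndex


def iter_chunked_indexes(
    worker_chunks: List[int], worker_intervals: List[Tuple[int, int]]
) -> Iterator[ChunkedIndex]:
    has_triggered_download = False
    for pos, (chunk_index, (start, stop)) in enumerate(zip(worker_chunks, worker_intervals)):
        indexes = list(range(start, stop))
        for i, index in enumerate(indexes):
            yield ChunkedIndex(
                index=index,
                chunk_index=chunk_index,
                # The chunk indexes are provided only once, with the very first item
                chunk_indexes=None if has_triggered_download else worker_chunks,
                # True only for the last item of the last chunk, so the reader can stop its thread
                last_index=pos == len(worker_chunks) - 1 and i == len(indexes) - 1,
            )
            has_triggered_download = True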

View File

@ -42,7 +42,7 @@ def _get_input_dir(inputs: Sequence[Any]) -> str:
if len(indexed_paths) == 0:
raise ValueError(f"The provided item {inputs[0]} didn't contain any filepaths.")
absolute_path = str(Path(indexed_paths[0]).resolve())
absolute_path = str(Path(list(indexed_paths.values())[0]).resolve())
if indexed_paths[0] != absolute_path:
raise ValueError("The provided path should be absolute.")
@ -128,6 +128,7 @@ def map(
num_nodes: Optional[int] = None,
machine: Optional[str] = None,
num_downloaders: Optional[int] = None,
reorder_files: bool = True,
) -> None:
"""This function map a callbable over a collection of files possibly in a distributed way.
@ -141,6 +142,8 @@ def map(
num_nodes: When doing remote execution, the number of nodes to use.
machine: When doing remote execution, the machine to use.
num_downloaders: The number of downloaders per worker.
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
Set this to ``False`` if the order in which samples are processed should be preserved.
"""
if not isinstance(inputs, Sequence):
@ -168,6 +171,7 @@ def map(
num_workers=num_workers or os.cpu_count(),
fast_dev_run=fast_dev_run,
num_downloaders=num_downloaders,
reorder_files=reorder_files,
)
return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
return _execute(
@ -189,6 +193,7 @@ def optimize(
num_nodes: Optional[int] = None,
machine: Optional[str] = None,
num_downloaders: Optional[int] = None,
reorder_files: bool = True,
) -> None:
"""This function converts a dataset into chunks possibly in a distributed way.
@ -205,6 +210,8 @@ def optimize(
num_nodes: When doing remote execution, the number of nodes to use.
machine: When doing remote execution, the machine to use.
num_downloaders: The number of downloaders per worker.
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
Set this to ``False`` if the order in which samples are processed should be preserved.
"""
if not isinstance(inputs, Sequence):
@ -235,6 +242,7 @@ def optimize(
num_workers=num_workers or os.cpu_count(),
fast_dev_run=fast_dev_run,
num_downloaders=num_downloaders,
reorder_files=reorder_files,
)
return data_processor.run(
LambdaDataChunkRecipe(

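A hypothetical call showing the new `reorder_files` flag, mirroring the end-to-end test added further down; `fn`, `inputs`, and `output_dir` are made-up placeholders.

import torch

from lightning.data.streaming import functions


def fn(item):
    start, _ = item
    return torch.arange(start, start + 20).to(torch.int)


inputs = [(v, "/path/to/a.txt") for v in range(0, 200, 20)]

# Keep the original input order instead of sorting the files by size across workers
functions.optimize(fn, inputs, output_dir="target_dir", chunk_size=2, num_workers=2, reorder_files=False)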
View File

@ -110,7 +110,6 @@ class TokensLoader(BaseItemLoader):
super().__init__()
self._block_size = block_size
self._intervals: List[Tuple[int, int]] = []
self._mmaps: Dict[int, np.memmap] = {}
self._buffers: Dict[int, bytes] = {}
self._dtype: Optional[torch.dtype] = None
@ -123,16 +122,16 @@ class TokensLoader(BaseItemLoader):
raise ValueError("The provided chunks isn't properly setup.")
def generate_intervals(self) -> List[Tuple[int, int]]:
intervals = []
begin = 0
end = 0
for chunk in self._chunks:
dim = chunk["dim"]
num_blocks = dim // self._block_size
end += num_blocks
self._intervals.append((begin, end))
intervals.append((begin, end))
begin += num_blocks
return self._intervals
return intervals
def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
@ -149,6 +148,10 @@ class TokensLoader(BaseItemLoader):
if chunk_index not in self._mmaps:
# TODO: Add deletion and memmap close
chunk = self._chunks[chunk_index]
# Skip the header
# The header stores the number of items plus (chunk_size + 1) offsets,
# each encoded with the header dtype (np.uint32, i.e. 4 bytes)
offset = (1 + chunk["chunk_size"] + 1) * 4
mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=offset)
self._mmaps[chunk_index] = mmap
@ -157,5 +160,5 @@ class TokensLoader(BaseItemLoader):
assert self._dtype
buffer: bytes = self._buffers[chunk_index]
offset = self._dtype.itemsize * ((index - begin) if index >= begin else index + 1)
offset = self._dtype.itemsize * (index - begin) * self._block_size
return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
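
As a sanity check on the two offsets above, a small worked example; the concrete sizes are illustrative, not taken from the commit.

import numpy as np

block_size = 1024 + 1
chunk_size = 11                          # items stored in the chunk
itemsize = np.dtype(np.int32).itemsize   # 4 bytes per token in this example

# Header skipped by np.memmap: one item-count field plus (chunk_size + 1) offsets, each np.uint32
header_offset = (1 + chunk_size + 1) * 4
assert header_offset == 52

# Byte offset of block `index` inside the mapped buffer, relative to the chunk's first block `begin`
index, begin = 3, 0
item_offset = itemsize * (index - begin) * block_size
assert item_offset == 12300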

View File

@ -46,6 +46,7 @@ class PrepareChunksThread(Thread):
self._processed_chunks_counter = 0
self._delete_chunks = 0
self._pre_download = pre_download
self._should_stop = False
def download(self, chunk_indices: List[int]) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
@ -69,12 +70,28 @@ class PrepareChunksThread(Thread):
if os.path.exists(chunk_filepath):
os.remove(chunk_filepath)
def stop(self) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
with self._lock:
self._should_stop = True
def run(self) -> None:
while True:
with self._lock:
if self._should_stop:
if (
self._max_cache_size
and self._max_cache_size <= shutil.disk_usage(self._config._cache_dir).total
):
for chunk_index in self._chunks_index_to_be_deleted:
if chunk_index not in self._chunks_index_to_be_downloaded:
self._delete(chunk_index)
self._delete_chunks += 1
self._processed_chunks_counter = 0
return
# Wait for something to do
if len(self._chunks_index_to_be_downloaded) == 0 and len(self._chunks_index_to_be_deleted) == 0:
sleep(0.01)
continue
# Delete the chunks if we are missing disk space.
@ -93,7 +110,7 @@ class PrepareChunksThread(Thread):
# If we have already downloaded too many chunks, let's wait for processed chunks to catch up
if self._max_cache_size and (self._downloaded_chunks - self._processed_chunks) > self._pre_download:
sleep(0.01)
sleep(0.1)
continue
chunk_index = self._chunks_index_to_be_downloaded.pop(0)
@ -101,6 +118,9 @@ class PrepareChunksThread(Thread):
self._config.download_chunk_from_index(chunk_index)
self._downloaded_chunks += 1
# Sleep to release the lock
sleep(0.1)
class BinaryReader:
def __init__(
@ -141,7 +161,6 @@ class BinaryReader:
self._rank: Optional[int] = None
self._config: Optional[ChunksConfig] = None
self._prepare_thread: Optional[PrepareChunksThread] = None
self._chunks_index_to_be_downloaded: List[int] = []
self._item_loader = item_loader or PyTreeLoader()
self._last_chunk_index: Optional[int] = None
self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size))
@ -193,7 +212,6 @@ class BinaryReader:
self._prepare_thread = PrepareChunksThread(self._config, self._max_cache_size)
self._prepare_thread.start()
if index.chunk_indexes:
self._chunks_index_to_be_downloaded.extend(index.chunk_indexes)
self._prepare_thread.download(index.chunk_indexes)
# If the chunk_index isn't already in the download and delete queues, add it.
@ -208,7 +226,13 @@ class BinaryReader:
# Fetch the element
chunk_filepath, begin, _ = self.config[index]
return self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)
item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)
if index.last_index and self._prepare_thread:
self._prepare_thread.stop()
self._prepare_thread = None
return item
def get_length(self) -> int:
"""Get the number of samples across all chunks."""

View File

@ -25,6 +25,7 @@ class ChunkedIndex:
index: int
chunk_index: int
chunk_indexes: Optional[List[int]] = None
last_index: bool = False
class CacheBatchSampler:

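A hypothetical direct use of the extended dataclass, in the spirit of the updated reader test further down: the full chunk list rides with the first index, and `last_index=True` flags the final one.

from lightning.data.streaming.config import ChunkedIndex

first = ChunkedIndex(index=0, chunk_index=0, chunk_indexes=[0, 1, 2])
last = ChunkedIndex(index=24, chunk_index=2, last_index=True)
assert first.chunk_indexes == [0, 1, 2]
assert last.chunk_indexes is None and last.last_index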
View File

@ -13,7 +13,7 @@
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, List
from typing import Any, List, Tuple
import numpy as np
@ -58,15 +58,11 @@ class NoShuffle(Shuffle):
@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore
chunk_intervals = self.cache.get_chunk_intervals()
indexes = list(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[indexes]
chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
for index, (chunk_index, chunk_interval) in enumerate(zip(indexes, shuffled_chunk_intervals)):
replica_index = index % distributed_env.world_size
intervals_per_ranks: List[List[Tuple]] = [[] for _ in range(distributed_env.world_size)]
for chunk_index, chunk_interval in enumerate(chunk_intervals):
replica_index = chunk_index % distributed_env.world_size
chunks_per_ranks[replica_index].append(chunk_index)
intervals_per_ranks[replica_index].append(chunk_interval)
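
The simplified `NoShuffle` assignment above is a plain round-robin over chunks; a standalone sketch of the same idea (the function name and sample intervals are made up):

from typing import List, Tuple


def assign_chunks(chunk_intervals: List[Tuple[int, int]], world_size: int):
    chunks_per_ranks: List[List[int]] = [[] for _ in range(world_size)]
    intervals_per_ranks: List[List[Tuple[int, int]]] = [[] for _ in range(world_size)]
    for chunk_index, chunk_interval in enumerate(chunk_intervals):
        # Chunk i goes to rank i % world_size, and its interval travels with it
        replica_index = chunk_index % world_size
        chunks_per_ranks[replica_index].append(chunk_index)
        intervals_per_ranks[replica_index].append(chunk_interval)
    return chunks_per_ranks, intervals_per_ranks


# 5 chunks over 2 ranks -> rank 0 gets chunks [0, 2, 4], rank 1 gets [1, 3]
print(assign_chunks([(0, 10), (10, 20), (20, 30), (30, 40), (40, 50)], 2))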

View File

@ -22,6 +22,7 @@ from lightning import seed_everything
from lightning.data.streaming import Cache
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.serializers import Serializer
from lightning.data.utilities.env import _DistributedEnv
from lightning.fabric import Fabric
@ -276,3 +277,37 @@ def test_custom_serializer(tmpdir):
cache.done()
cache.merge()
assert isinstance(cache[0][0], bytes)
def test_cache_for_text_tokens(tmpdir):
seed_everything(42)
block_size = 1024 + 1
cache = Cache(input_dir=str(tmpdir), chunk_size=block_size * 11, item_loader=TokensLoader(block_size))
text_idxs_list = []
counter = 0
while True:
text_ids = torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int)
text_idxs_list.append(text_ids)
chunk_filepath = cache._add_item(counter, text_ids)
if chunk_filepath:
break
counter += 1
cache.done()
cache.merge()
assert len(cache) == 10
cache_0 = cache[0]
cache_1 = cache[1]
assert len(cache_0) == block_size
assert len(cache_1) == block_size
assert not torch.equal(cache_0, cache[1])
indices = torch.cat(text_idxs_list, dim=0)
assert torch.equal(cache_0, indices[: len(cache_0)])
assert torch.equal(cache_1, indices[len(cache_0) : len(cache_0) + len(cache_1)])
with pytest.raises(ValueError, match="TokensLoader"):
len(Cache(str(tmpdir), chunk_size=block_size * 11))

View File

@ -15,10 +15,13 @@ import os
import sys
from unittest import mock
import numpy as np
import pytest
import torch
from lightning import seed_everything
from lightning.data.streaming import Cache
from lightning.data.streaming import Cache, functions
from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
from lightning.data.utilities.env import _DistributedEnv
from torch.utils.data import DataLoader
@ -326,3 +329,208 @@ def test_try_create_cache_dir():
assert _try_create_cache_dir("dir", shard_rank=2) == os.path.join(
"/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "2"
)
def test_dataset_for_text_tokens(tmpdir):
seed_everything(42)
block_size = 1024 + 1
cache = Cache(input_dir=str(tmpdir), chunk_size=block_size * 11, item_loader=TokensLoader(block_size))
text_idxs_list = []
counter = 0
while True:
text_ids = torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int)
text_idxs_list.append(text_ids)
chunk_filepath = cache._add_item(counter, text_ids)
if chunk_filepath:
break
counter += 1
cache.done()
cache.merge()
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size))
assert len(dataset) == 10
cache_0 = dataset[0]
cache_1 = dataset[1]
cache_2 = dataset[2]
cache_3 = dataset[3]
assert len(cache_0) == block_size
assert len(cache_1) == block_size
assert not torch.equal(cache_0, cache[1])
indices = torch.cat(text_idxs_list, dim=0)
assert torch.equal(cache_0, indices[: len(cache_0)])
assert torch.equal(cache_1, indices[len(cache_0) : len(cache_0) + len(cache_1)])
dataloader = DataLoader(StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size)), batch_size=2)
for batch_idx, batch in enumerate(dataloader):
if batch_idx == 0:
assert torch.equal(torch.stack([cache_0, cache_1]), batch)
elif batch_idx == 1:
assert torch.equal(torch.stack([cache_2, cache_3]), batch)
else:
break
def test_dataset_for_text_tokens_multiple_workers(tmpdir):
seed_everything(42)
block_size = 10
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
counter = 0
for i in range(10):
text_ids = torch.arange(counter, counter + 20).to(torch.int)
cache[i] = text_ids
counter += 20
cache.done()
cache.merge()
for i in range(20):
sequence = cache[i]
assert sequence[0].item() == i * block_size
assert sequence[-1].item() == (i + 1) * block_size - 1
assert len(os.listdir(tmpdir)) == 6
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
assert len(dataset) == 20
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
assert len(dataloader) == 10
expected = [
[0, 10],
[40, 50],
[20, 30],
[60, 70],
[80, 90],
[120, 130],
[100, 110],
[140, 150],
[160, 170],
[180, 190],
]
for result, batch in zip(expected, dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == result
def test_dataset_for_text_tokens_distributed_num_workers(tmpdir):
seed_everything(42)
block_size = 10
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
counter = 0
for i in range(10):
text_ids = torch.arange(counter, counter + 20).to(torch.int)
cache[i] = text_ids
counter += 20
cache.done()
cache.merge()
for i in range(20):
sequence = cache[i]
assert sequence[0].item() == i * block_size
assert sequence[-1].item() == (i + 1) * block_size - 1
assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 5
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
assert len(dataset) == 20
dataset.distributed_env = _DistributedEnv(2, 0)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
assert len(dataloader) == 6
expected = [[0, 10], [80, 90], [20, 30], [100, 110], [160, 170], [180, 190]]
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
dataset.distributed_env = _DistributedEnv(2, 1)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
assert len(dataloader) == 4
expected = [[40, 50], [60, 70], [120, 130], [140, 150]]
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
def optimize_fn(item):
return torch.arange(item[0], item[0] + 20).to(torch.int)
def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monkeypatch):
monkeypatch.setattr(functions, "_get_input_dir", lambda x: str(tmpdir))
seed_everything(42)
with open(tmpdir / "a.txt", "w") as f:
f.write("hello")
inputs = [(v, str(tmpdir / "a.txt")) for v in range(0, 200, 20)]
cache_dir = os.path.join(tmpdir, "cache")
output_dir = os.path.join(tmpdir, "target_dir")
os.makedirs(output_dir, exist_ok=True)
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir)
def fn(item):
return torch.arange(item[0], item[0] + 20).to(torch.int)
functions.optimize(
optimize_fn, inputs, output_dir=str(tmpdir), num_workers=2, chunk_size=2, reorder_files=False, num_downloaders=1
)
assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 10
block_size = 10
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
L = len(dataset)
assert len(dataset) == L
for i in range(L):
sequence = dataset[i]
assert sequence[0].item() == i * block_size
assert sequence[-1].item() == (i + 1) * block_size - 1
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
dataset.distributed_env = _DistributedEnv(2, 0)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
assert len(dataloader) == 5
expected = [[0, 10], [40, 50], [80, 90], [120, 130], [160, 170]]
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
dataset.distributed_env = _DistributedEnv(2, 1)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
assert len(dataloader) == 5
expected = [[20, 30], [60, 70], [100, 110], [140, 150], [180, 190]]
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]

View File

@ -4,6 +4,7 @@ from unittest import mock
from lightning.data.streaming import reader
from lightning.data.streaming.cache import Cache
from lightning.data.streaming.config import ChunkedIndex
from lightning_cloud.resolver import Dir
@ -30,7 +31,8 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
os.makedirs(cache_dir, exist_ok=True)
for i in range(25):
assert cache[i] == i
index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), last_index=i == 24)
assert cache[index] == i
assert len(os.listdir(cache_dir)) == 14
@ -46,7 +48,8 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
expected = []
for i in range(25):
expected.append([i, len(os.listdir(cache_dir))])
assert cache[i] == i
index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), last_index=i == 24)
assert cache[index] == i
assert expected == [
[0, 0],
@ -70,10 +73,10 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
[18, 9],
[19, 10],
[20, 10],
[21, 11],
[22, 11], # Cleanup is triggered
[23, 2],
[24, 2],
[21, 2],
[22, 2],
[23, 3],
[24, 3],
]
assert len(os.listdir(cache_dir)) == 3
assert len(os.listdir(cache_dir)) in [3, 4]