Resolve Item Loader bugs (#19017)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3d448ac48d
commit 6e517bd55b
@@ -11,7 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import copy
 import hashlib
 import os
 from typing import Any, Dict, List, Optional, Union

@@ -84,12 +83,16 @@ class StreamingDataset(IterableDataset):

     def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
         env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
-        cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
-        cache_dir = copy.deepcopy(self.input_dir)
-        if cache_path:
-            cache_dir.path = cache_path
-
-        cache = Cache(input_dir=cache_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers)
+
+        # TODO: Move this to lightning-cloud
+        if "this_" not in self.input_dir.path:
+            cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
+            if cache_path is not None:
+                self.input_dir.path = cache_path
+
+        cache = Cache(
+            input_dir=self.input_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers
+        )
         cache._reader._try_load_config()

         if not cache.filled:

@@ -136,6 +139,7 @@ class StreamingDataset(IterableDataset):
         self.current_indexes = []
         self.chunk_index = 0
         self.index = 0
         self.has_triggered_download = False

         return self

@@ -167,6 +171,8 @@ class StreamingDataset(IterableDataset):
             self.current_indexes = self.shuffler(current_indexes)
             self.chunk_index += 1

+        last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1
+
         # Get the first index
         index = self.current_indexes.pop(0)

@@ -175,7 +181,9 @@ class StreamingDataset(IterableDataset):
             ChunkedIndex(
                 index=index,
                 chunk_index=self.worker_chunks[self.chunk_index - 1],
                 # We provide the chunks indexes only on the first call
                 chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
+                last_index=last_index,
             )
         )
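The `last_index` flag computed in `__next__` above marks the very last sample of the epoch so the reader can shut down its background download thread (see the reader changes further down). Below is a minimal, self-contained sketch of that bookkeeping only; the class and attribute names are invented for illustration, and only the flag computation mirrors the diff.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyChunkedIndex:
    index: int
    chunk_index: int
    chunk_indexes: Optional[List[int]] = None
    last_index: bool = False  # True only for the final sample of the epoch


class ToyDataset:
    """Hypothetical stand-in: yields one ToyChunkedIndex per sample, chunk by chunk."""

    def __init__(self, intervals: List[range]) -> None:
        self.intervals = intervals          # one range of sample ids per chunk
        self.chunk_index = 0
        self.current: List[int] = []

    def __iter__(self) -> "ToyDataset":
        self.chunk_index = 0
        self.current = []
        return self

    def __next__(self) -> ToyChunkedIndex:
        if not self.current:
            if self.chunk_index == len(self.intervals):
                raise StopIteration
            self.current = list(self.intervals[self.chunk_index])
            self.chunk_index += 1
        # Mirrors the diff: the last chunk is loaded and a single sample remains.
        last = self.chunk_index == len(self.intervals) and len(self.current) == 1
        return ToyChunkedIndex(index=self.current.pop(0), chunk_index=self.chunk_index - 1, last_index=last)


flags = [idx.last_index for idx in ToyDataset([range(0, 3), range(3, 5)])]
assert flags == [False, False, False, False, True]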
@@ -42,7 +42,7 @@ def _get_input_dir(inputs: Sequence[Any]) -> str:
     if len(indexed_paths) == 0:
         raise ValueError(f"The provided item {inputs[0]} didn't contain any filepaths.")

-    absolute_path = str(Path(indexed_paths[0]).resolve())
+    absolute_path = str(Path(list(indexed_paths.values())[0]).resolve())

     if indexed_paths[0] != absolute_path:
         raise ValueError("The provided path should be absolute.")

@@ -128,6 +128,7 @@ def map(
     num_nodes: Optional[int] = None,
     machine: Optional[str] = None,
     num_downloaders: Optional[int] = None,
+    reorder_files: bool = True,
 ) -> None:
     """This function maps a callable over a collection of files, possibly in a distributed way.

@@ -141,6 +142,8 @@ def map(
         num_nodes: When doing remote execution, the number of nodes to use.
         machine: When doing remote execution, the machine to use.
         num_downloaders: The number of downloaders per worker.
+        reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
+            Set this to ``False`` if the order in which samples are processed should be preserved.

     """
     if not isinstance(inputs, Sequence):

@@ -168,6 +171,7 @@ def map(
             num_workers=num_workers or os.cpu_count(),
             fast_dev_run=fast_dev_run,
             num_downloaders=num_downloaders,
+            reorder_files=reorder_files,
         )
         return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
     return _execute(

@@ -189,6 +193,7 @@ def optimize(
     num_nodes: Optional[int] = None,
     machine: Optional[str] = None,
     num_downloaders: Optional[int] = None,
+    reorder_files: bool = True,
 ) -> None:
     """This function converts a dataset into chunks, possibly in a distributed way.

@@ -205,6 +210,8 @@ def optimize(
         num_nodes: When doing remote execution, the number of nodes to use.
         machine: When doing remote execution, the machine to use.
         num_downloaders: The number of downloaders per worker.
+        reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
+            Set this to ``False`` if the order in which samples are processed should be preserved.

     """
     if not isinstance(inputs, Sequence):

@@ -235,6 +242,7 @@ def optimize(
             num_workers=num_workers or os.cpu_count(),
             fast_dev_run=fast_dev_run,
             num_downloaders=num_downloaders,
+            reorder_files=reorder_files,
         )
         return data_processor.run(
             LambdaDataChunkRecipe(
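For reference, a call-shape sketch of the new `reorder_files` flag, mirroring the end-to-end test further down. The paths, the `(start, filepath)` item format, and `to_sequence` are placeholders for illustration rather than a recipe guaranteed to run in every environment.

import torch
from lightning.data.streaming import functions


def to_sequence(item):
    start, _filepath = item  # each input is a (start, filepath) pair, as in the test below
    return torch.arange(start, start + 20).to(torch.int)


functions.optimize(
    to_sequence,
    [(v, "/abs/path/a.txt") for v in range(0, 200, 20)],
    output_dir="/abs/path/chunks",
    num_workers=2,
    chunk_size=2,
    num_downloaders=1,
    reorder_files=False,  # keep the declared input order instead of sorting by file size
)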
@@ -110,7 +110,6 @@ class TokensLoader(BaseItemLoader):

         super().__init__()
         self._block_size = block_size
-        self._intervals: List[Tuple[int, int]] = []
         self._mmaps: Dict[int, np.memmap] = {}
         self._buffers: Dict[int, bytes] = {}
         self._dtype: Optional[torch.dtype] = None

@@ -123,16 +122,16 @@ class TokensLoader(BaseItemLoader):
             raise ValueError("The provided chunks isn't properly setup.")

     def generate_intervals(self) -> List[Tuple[int, int]]:
+        intervals = []
         begin = 0
         end = 0
         for chunk in self._chunks:
             dim = chunk["dim"]
             num_blocks = dim // self._block_size
             end += num_blocks
-            self._intervals.append((begin, end))
+            intervals.append((begin, end))
             begin += num_blocks

-        return self._intervals
+        return intervals

     def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
         if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):

@@ -149,6 +148,10 @@ class TokensLoader(BaseItemLoader):
         if chunk_index not in self._mmaps:
             # TODO: Add deletion and memmap close
             chunk = self._chunks[chunk_index]

+            # Skip the header
+            # The number of items + the number of offsets (number of items in the chunk + 1)
+            # multiplied by the header encoding dtype (np.uint32)
+            offset = (1 + chunk["chunk_size"] + 1) * 4
             mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=offset)
             self._mmaps[chunk_index] = mmap

@@ -157,5 +160,5 @@ class TokensLoader(BaseItemLoader):
         assert self._dtype

         buffer: bytes = self._buffers[chunk_index]
-        offset = self._dtype.itemsize * ((index - begin) if index >= begin else index + 1)
+        offset = self._dtype.itemsize * (index - begin) * self._block_size
         return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
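The two offset fixes above follow from the chunk layout: a header of `(1 + chunk_size + 1)` values encoded as `np.uint32` (4 bytes each) precedes the token payload, and block `index` starts `itemsize * (index - begin) * block_size` bytes into that payload. A standalone sketch of the same arithmetic on a synthetic chunk file; the header contents are irrelevant for the read path and are zeroed out here.

import os
import tempfile

import numpy as np
import torch

block_size = 4
chunk_size = 3                                          # items stored in this chunk
tokens = np.arange(chunk_size * block_size, dtype=np.int32)

path = os.path.join(tempfile.mkdtemp(), "chunk.bin")
header = np.zeros(1 + chunk_size + 1, dtype=np.uint32)  # count + offsets, as described in the diff
with open(path, "wb") as f:
    f.write(header.tobytes())
    f.write(tokens.tobytes())

payload_offset = (1 + chunk_size + 1) * 4               # skip the uint32 header
mmap = np.memmap(path, mode="r", order="C", offset=payload_offset)
buffer = bytes(mmap)

index, begin = 2, 0                                     # third block of this chunk
item_offset = np.dtype(np.int32).itemsize * (index - begin) * block_size
block = torch.frombuffer(buffer, dtype=torch.int32, count=block_size, offset=item_offset)
assert block.tolist() == [8, 9, 10, 11]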
@@ -46,6 +46,7 @@ class PrepareChunksThread(Thread):
         self._processed_chunks_counter = 0
         self._delete_chunks = 0
         self._pre_download = pre_download
+        self._should_stop = False

     def download(self, chunk_indices: List[int]) -> None:
         """Receive the list of the chunk indices to download for the current epoch."""

@@ -69,12 +70,28 @@ class PrepareChunksThread(Thread):
             if os.path.exists(chunk_filepath):
                 os.remove(chunk_filepath)

+    def stop(self) -> None:
+        """Signal the thread to stop as soon as possible."""
+        with self._lock:
+            self._should_stop = True
+
     def run(self) -> None:
         while True:
             with self._lock:
+                if self._should_stop:
+                    if (
+                        self._max_cache_size
+                        and self._max_cache_size <= shutil.disk_usage(self._config._cache_dir).total
+                    ):
+                        for chunk_index in self._chunks_index_to_be_deleted:
+                            if chunk_index not in self._chunks_index_to_be_downloaded:
+                                self._delete(chunk_index)
+                                self._delete_chunks += 1
+                                self._processed_chunks_counter = 0
+                    return
+
                 # Wait for something to do
                 if len(self._chunks_index_to_be_downloaded) == 0 and len(self._chunks_index_to_be_deleted) == 0:
                     sleep(0.01)
                     continue

                 # Delete the chunks if we are missing disk space.

@@ -93,7 +110,7 @@ class PrepareChunksThread(Thread):

                 # If we have already downloaded too many chunks, let's wait for processed chunks to catch up
                 if self._max_cache_size and (self._downloaded_chunks - self._processed_chunks) > self._pre_download:
-                    sleep(0.01)
+                    sleep(0.1)
                     continue

                 chunk_index = self._chunks_index_to_be_downloaded.pop(0)

@@ -101,6 +118,9 @@ class PrepareChunksThread(Thread):
                 self._config.download_chunk_from_index(chunk_index)
                 self._downloaded_chunks += 1

+            # Sleep to release the lock
+            sleep(0.1)
+

 class BinaryReader:
     def __init__(

@@ -141,7 +161,6 @@ class BinaryReader:
         self._rank: Optional[int] = None
         self._config: Optional[ChunksConfig] = None
         self._prepare_thread: Optional[PrepareChunksThread] = None
-        self._chunks_index_to_be_downloaded: List[int] = []
         self._item_loader = item_loader or PyTreeLoader()
         self._last_chunk_index: Optional[int] = None
         self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size))

@@ -193,7 +212,6 @@ class BinaryReader:
             self._prepare_thread = PrepareChunksThread(self._config, self._max_cache_size)
             self._prepare_thread.start()
             if index.chunk_indexes:
-                self._chunks_index_to_be_downloaded.extend(index.chunk_indexes)
                 self._prepare_thread.download(index.chunk_indexes)

         # If the chunk_index isn't already in the download and delete queues, add it.

@@ -208,7 +226,13 @@ class BinaryReader:

         # Fetch the element
         chunk_filepath, begin, _ = self.config[index]
-        return self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)
+        item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)
+
+        if index.last_index and self._prepare_thread:
+            self._prepare_thread.stop()
+            self._prepare_thread = None
+
+        return item

     def get_length(self) -> int:
         """Get the number of samples across all chunks."""
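The shutdown mechanism added above reduces to a flag set under the thread's lock and checked at the top of the worker loop, with the reader calling `stop()` once the last index has been served. A self-contained sketch of the same pattern with invented names; the idle wait is moved outside the lock to keep the sketch simple.

from threading import Lock, Thread
from time import sleep


class Prefetcher(Thread):
    """Toy stand-in for a background download thread with cooperative shutdown."""

    def __init__(self) -> None:
        super().__init__(daemon=True)
        self._lock = Lock()
        self._should_stop = False
        self._queue: list = []

    def stop(self) -> None:
        """Ask the worker loop to exit at its next iteration."""
        with self._lock:
            self._should_stop = True

    def run(self) -> None:
        while True:
            with self._lock:
                if self._should_stop:
                    return
                item = self._queue.pop(0) if self._queue else None
            if item is None:
                sleep(0.01)  # nothing to do yet; wait outside the lock
                continue
            # a real implementation would download `item` here, outside the lock
            sleep(0.1)


thread = Prefetcher()
thread.start()
thread.stop()
thread.join(timeout=1)
assert not thread.is_alive()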
@@ -25,6 +25,7 @@ class ChunkedIndex:
     index: int
     chunk_index: int
     chunk_indexes: Optional[List[int]] = None
+    last_index: bool = False


 class CacheBatchSampler:
@@ -13,7 +13,7 @@

 from abc import ABC, abstractmethod
 from functools import lru_cache
-from typing import Any, List
+from typing import Any, List, Tuple

 import numpy as np

@@ -58,15 +58,11 @@ class NoShuffle(Shuffle):

     @lru_cache(maxsize=10)
     def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
-        self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
         chunk_intervals = self.cache.get_chunk_intervals()
-        indexes = list(range(len(chunk_intervals)))
-        shuffled_chunk_intervals = np.asarray(chunk_intervals)[indexes]

         chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
-        intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
-        for index, (chunk_index, chunk_interval) in enumerate(zip(indexes, shuffled_chunk_intervals)):
-            replica_index = index % distributed_env.world_size
+        intervals_per_ranks: List[List[Tuple]] = [[] for _ in range(distributed_env.world_size)]
+        for chunk_index, chunk_interval in enumerate(chunk_intervals):
+            replica_index = chunk_index % distributed_env.world_size
             chunks_per_ranks[replica_index].append(chunk_index)
             intervals_per_ranks[replica_index].append(chunk_interval)
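With the shuffling removed, `NoShuffle` now assigns chunks to ranks round-robin: chunk `i` goes to rank `i % world_size` together with its `(begin, end)` interval. A small self-contained sketch of that assignment with made-up intervals; the same split is what the distributed dataset tests below encode in their `expected` batches.

from typing import List, Tuple


def assign(chunk_intervals: List[Tuple[int, int]], world_size: int):
    chunks_per_ranks: List[List[int]] = [[] for _ in range(world_size)]
    intervals_per_ranks: List[List[Tuple[int, int]]] = [[] for _ in range(world_size)]
    for chunk_index, chunk_interval in enumerate(chunk_intervals):
        replica_index = chunk_index % world_size
        chunks_per_ranks[replica_index].append(chunk_index)
        intervals_per_ranks[replica_index].append(chunk_interval)
    return chunks_per_ranks, intervals_per_ranks


chunks, intervals = assign([(0, 4), (4, 8), (8, 12), (12, 16)], world_size=2)
assert chunks == [[0, 2], [1, 3]]
assert intervals == [[(0, 4), (8, 12)], [(4, 8), (12, 16)]]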
@@ -22,6 +22,7 @@ from lightning import seed_everything
 from lightning.data.streaming import Cache
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
+from lightning.data.streaming.item_loader import TokensLoader
 from lightning.data.streaming.serializers import Serializer
 from lightning.data.utilities.env import _DistributedEnv
 from lightning.fabric import Fabric

@@ -276,3 +277,37 @@ def test_custom_serializer(tmpdir):
     cache.done()
     cache.merge()
     assert isinstance(cache[0][0], bytes)
+
+
+def test_cache_for_text_tokens(tmpdir):
+    seed_everything(42)
+
+    block_size = 1024 + 1
+    cache = Cache(input_dir=str(tmpdir), chunk_size=block_size * 11, item_loader=TokensLoader(block_size))
+    text_idxs_list = []
+
+    counter = 0
+    while True:
+        text_ids = torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int)
+        text_idxs_list.append(text_ids)
+        chunk_filepath = cache._add_item(counter, text_ids)
+        if chunk_filepath:
+            break
+        counter += 1
+
+    cache.done()
+    cache.merge()
+
+    assert len(cache) == 10
+
+    cache_0 = cache[0]
+    cache_1 = cache[1]
+    assert len(cache_0) == block_size
+    assert len(cache_1) == block_size
+    assert not torch.equal(cache_0, cache[1])
+    indices = torch.cat(text_idxs_list, dim=0)
+    assert torch.equal(cache_0, indices[: len(cache_0)])
+    assert torch.equal(cache_1, indices[len(cache_0) : len(cache_0) + len(cache_1)])
+
+    with pytest.raises(ValueError, match="TokensLoader"):
+        len(Cache(str(tmpdir), chunk_size=block_size * 11))
@@ -15,10 +15,13 @@ import os
 import sys
 from unittest import mock

+import numpy as np
 import pytest
 import torch
 from lightning import seed_everything
-from lightning.data.streaming import Cache
+from lightning.data.streaming import Cache, functions
 from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
+from lightning.data.streaming.item_loader import TokensLoader
 from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
 from lightning.data.utilities.env import _DistributedEnv
+from torch.utils.data import DataLoader

@@ -326,3 +329,208 @@ def test_try_create_cache_dir():
         assert _try_create_cache_dir("dir", shard_rank=2) == os.path.join(
             "/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "2"
         )
+
+
+def test_dataset_for_text_tokens(tmpdir):
+    seed_everything(42)
+
+    block_size = 1024 + 1
+    cache = Cache(input_dir=str(tmpdir), chunk_size=block_size * 11, item_loader=TokensLoader(block_size))
+    text_idxs_list = []
+
+    counter = 0
+    while True:
+        text_ids = torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int)
+        text_idxs_list.append(text_ids)
+        chunk_filepath = cache._add_item(counter, text_ids)
+        if chunk_filepath:
+            break
+        counter += 1
+
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size))
+
+    assert len(dataset) == 10
+
+    cache_0 = dataset[0]
+    cache_1 = dataset[1]
+    cache_2 = dataset[2]
+    cache_3 = dataset[3]
+    assert len(cache_0) == block_size
+    assert len(cache_1) == block_size
+    assert not torch.equal(cache_0, cache[1])
+    indices = torch.cat(text_idxs_list, dim=0)
+    assert torch.equal(cache_0, indices[: len(cache_0)])
+    assert torch.equal(cache_1, indices[len(cache_0) : len(cache_0) + len(cache_1)])
+
+    dataloader = DataLoader(StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size)), batch_size=2)
+
+    for batch_idx, batch in enumerate(dataloader):
+        if batch_idx == 0:
+            assert torch.equal(torch.stack([cache_0, cache_1]), batch)
+        elif batch_idx == 1:
+            assert torch.equal(torch.stack([cache_2, cache_3]), batch)
+        else:
+            break
+
+
+def test_dataset_for_text_tokens_multiple_workers(tmpdir):
+    seed_everything(42)
+
+    block_size = 10
+    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+
+    counter = 0
+    for i in range(10):
+        text_ids = torch.arange(counter, counter + 20).to(torch.int)
+        cache[i] = text_ids
+        counter += 20
+
+    cache.done()
+    cache.merge()
+
+    for i in range(20):
+        sequence = cache[i]
+        assert sequence[0].item() == i * block_size
+        assert sequence[-1].item() == (i + 1) * block_size - 1
+
+    assert len(os.listdir(tmpdir)) == 6
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+
+    assert len(dataset) == 20
+
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
+
+    assert len(dataloader) == 10
+
+    expected = [
+        [0, 10],
+        [40, 50],
+        [20, 30],
+        [60, 70],
+        [80, 90],
+        [120, 130],
+        [100, 110],
+        [140, 150],
+        [160, 170],
+        [180, 190],
+    ]
+
+    for result, batch in zip(expected, dataloader):
+        assert [batch[0][0].item(), batch[1][0].item()] == result
+
+
+def test_dataset_for_text_tokens_distributed_num_workers(tmpdir):
+    seed_everything(42)
+
+    block_size = 10
+    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+
+    counter = 0
+    for i in range(10):
+        text_ids = torch.arange(counter, counter + 20).to(torch.int)
+        cache[i] = text_ids
+        counter += 20
+
+    cache.done()
+    cache.merge()
+
+    for i in range(20):
+        sequence = cache[i]
+        assert sequence[0].item() == i * block_size
+        assert sequence[-1].item() == (i + 1) * block_size - 1
+
+    assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 5
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+
+    assert len(dataset) == 20
+
+    dataset.distributed_env = _DistributedEnv(2, 0)
+    dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
+
+    assert len(dataloader) == 6
+
+    expected = [[0, 10], [80, 90], [20, 30], [100, 110], [160, 170], [180, 190]]
+
+    for batch_idx, batch in enumerate(dataloader):
+        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
+
+    dataset.distributed_env = _DistributedEnv(2, 1)
+    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
+
+    assert len(dataloader) == 4
+
+    expected = [[40, 50], [60, 70], [120, 130], [140, 150]]
+
+    for batch_idx, batch in enumerate(dataloader):
+        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
+
+    for batch_idx, batch in enumerate(dataloader):
+        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
+
+
+def optimize_fn(item):
+    return torch.arange(item[0], item[0] + 20).to(torch.int)
+
+
+def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monkeypatch):
+    monkeypatch.setattr(functions, "_get_input_dir", lambda x: str(tmpdir))
+
+    seed_everything(42)
+
+    with open(tmpdir / "a.txt", "w") as f:
+        f.write("hello")
+
+    inputs = [(v, str(tmpdir / "a.txt")) for v in range(0, 200, 20)]
+
+    cache_dir = os.path.join(tmpdir, "cache")
+    output_dir = os.path.join(tmpdir, "target_dir")
+    os.makedirs(output_dir, exist_ok=True)
+    monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
+    monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir)
+
+    def fn(item):
+        return torch.arange(item[0], item[0] + 20).to(torch.int)
+
+    functions.optimize(
+        optimize_fn, inputs, output_dir=str(tmpdir), num_workers=2, chunk_size=2, reorder_files=False, num_downloaders=1
+    )
+
+    assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 10
+
+    block_size = 10
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+
+    L = len(dataset)
+    assert len(dataset) == L
+
+    for i in range(L):
+        sequence = dataset[i]
+        assert sequence[0].item() == i * block_size
+        assert sequence[-1].item() == (i + 1) * block_size - 1
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+
+    dataset.distributed_env = _DistributedEnv(2, 0)
+    dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
+
+    assert len(dataloader) == 5
+
+    expected = [[0, 10], [40, 50], [80, 90], [120, 130], [160, 170]]
+
+    for batch_idx, batch in enumerate(dataloader):
+        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
+
+    dataset.distributed_env = _DistributedEnv(2, 1)
+    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
+
+    assert len(dataloader) == 5
+
+    expected = [[20, 30], [60, 70], [100, 110], [140, 150], [180, 190]]
+
+    for batch_idx, batch in enumerate(dataloader):
+        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
@@ -4,6 +4,7 @@ from unittest import mock

 from lightning.data.streaming import reader
 from lightning.data.streaming.cache import Cache
+from lightning.data.streaming.config import ChunkedIndex
 from lightning_cloud.resolver import Dir


@@ -30,7 +31,8 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
     os.makedirs(cache_dir, exist_ok=True)

     for i in range(25):
-        assert cache[i] == i
+        index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), last_index=i == 24)
+        assert cache[index] == i

     assert len(os.listdir(cache_dir)) == 14

@@ -46,7 +48,8 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
     expected = []
     for i in range(25):
         expected.append([i, len(os.listdir(cache_dir))])
-        assert cache[i] == i
+        index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), last_index=i == 24)
+        assert cache[index] == i

     assert expected == [
         [0, 0],

@@ -70,10 +73,10 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
         [18, 9],
         [19, 10],
         [20, 10],
-        [21, 11],
-        [22, 11],  # Cleanup is triggered
-        [23, 2],
-        [24, 2],
+        [21, 2],
+        [22, 2],
+        [23, 3],
+        [24, 3],
     ]

-    assert len(os.listdir(cache_dir)) == 3
+    assert len(os.listdir(cache_dir)) in [3, 4]