Prevent downloading more chunks than needed (#18964)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
thomas chaton 2023-11-07 19:40:21 +00:00 committed by GitHub
parent 20f58f63ef
commit 1b3a3fbaad
5 changed files with 86 additions and 39 deletions
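The change itself is small: the chunk-preparation thread now counts how many chunks it has downloaded and how many the reader has already processed, and it pauses downloading once it is more than `pre_download` chunks ahead. The block below is a minimal, self-contained sketch of that throttling pattern, not the Lightning implementation; the `_ThrottledDownloader` class and the `fetch_fn` callable are illustrative placeholders, although the counters mirror the fields added in the diff.

```python
from threading import Lock, Thread
from time import sleep
from typing import Callable, List


class _ThrottledDownloader(Thread):
    """Background downloader that stays at most `pre_download` chunks ahead (sketch only)."""

    def __init__(self, fetch_fn: Callable[[int], None], pre_download: int = 10) -> None:
        super().__init__(daemon=True)
        self._fetch_fn = fetch_fn              # downloads a single chunk index (assumed callable)
        self._to_download: List[int] = []      # queue of chunk indices
        self._lock = Lock()
        self._downloaded_chunks = 0
        self._processed_chunks = 0
        self._pre_download = pre_download

    def download(self, chunk_indices: List[int]) -> None:
        """Queue chunk indices, skipping duplicates."""
        with self._lock:
            for chunk_index in chunk_indices:
                if chunk_index not in self._to_download:
                    self._to_download.append(chunk_index)

    def mark_processed(self) -> None:
        """Called by the consumer once it has finished with a chunk."""
        with self._lock:
            self._processed_chunks += 1

    def run(self) -> None:
        while True:
            with self._lock:
                too_far_ahead = (
                    self._downloaded_chunks - self._processed_chunks > self._pre_download
                )
                chunk_index = None
                if self._to_download and not too_far_ahead:
                    chunk_index = self._to_download.pop(0)
            if chunk_index is None:
                # Nothing queued, or already far enough ahead: back off briefly.
                sleep(0.01)
                continue
            self._fetch_fn(chunk_index)
            with self._lock:
                self._downloaded_chunks += 1
```

Driving the sketch would look like `thread = _ThrottledDownloader(fetch_fn); thread.start(); thread.download([0, 1, 2])`, with the consumer calling `mark_processed()` after each chunk it finishes.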


@@ -34,20 +34,25 @@ if _TORCH_GREATER_EQUAL_2_1_0:
class PrepareChunksThread(Thread):
"""This thread is responsible to download the chunks associated to a given worker."""
def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None) -> None:
def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None, pre_download: int = 10) -> None:
super().__init__(daemon=True)
self._config = config
self._chunks_index_to_be_processed: List[int] = []
self._chunks_index_to_be_downloaded: List[int] = []
self._chunks_index_to_be_deleted: List[int] = []
self._lock = Lock()
self._max_cache_size = max_cache_size
self._downloaded_chunks = 0
self._processed_chunks = 0
self._processed_chunks_counter = 0
self._delete_chunks = 0
self._pre_download = pre_download
def add(self, chunk_indices: List[int]) -> None:
def download(self, chunk_indices: List[int]) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
with self._lock:
for chunk_indice in chunk_indices:
if chunk_indice not in self._chunks_index_to_be_processed:
self._chunks_index_to_be_processed.append(chunk_indice)
if chunk_indice not in self._chunks_index_to_be_downloaded:
self._chunks_index_to_be_downloaded.append(chunk_indice)
def delete(self, chunk_indices: List[int]) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
@@ -55,6 +60,8 @@ class PrepareChunksThread(Thread):
for chunk_indice in chunk_indices:
if chunk_indice not in self._chunks_index_to_be_deleted:
self._chunks_index_to_be_deleted.append(chunk_indice)
self._processed_chunks += 1
self._processed_chunks_counter += 1
def _delete(self, chunk_index: int) -> None:
chunk_filepath, begin, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
@@ -65,25 +72,34 @@
def run(self) -> None:
while True:
with self._lock:
if len(self._chunks_index_to_be_processed) == 0 and len(self._chunks_index_to_be_deleted) == 0:
sleep(0.005)
# Wait for something to do
if len(self._chunks_index_to_be_downloaded) == 0 and len(self._chunks_index_to_be_deleted) == 0:
sleep(0.01)
continue
# Delete the chunks if we are missing disk space.
# Check every 10 items to avoid losing too much time
if self._max_cache_size and len(self._chunks_index_to_be_deleted) > 10:
if self._max_cache_size and self._processed_chunks_counter >= self._pre_download:
if shutil.disk_usage(self._config._cache_dir).total >= self._max_cache_size:
for chunk_index in self._chunks_index_to_be_deleted:
if chunk_index not in self._chunks_index_to_be_processed:
if chunk_index not in self._chunks_index_to_be_downloaded:
self._delete(chunk_index)
self._delete_chunks += 1
self._processed_chunks_counter = 0
self._chunks_index_to_be_deleted = []
if len(self._chunks_index_to_be_processed) == 0:
# If there are no chunks to download, go back to waiting
if len(self._chunks_index_to_be_downloaded) == 0:
continue
chunk_index = self._chunks_index_to_be_processed.pop(0)
# If we have already downloaded too many chunks, let's wait for processed chunks to catch up
if self._max_cache_size and (self._downloaded_chunks - self._processed_chunks) > self._pre_download:
sleep(0.01)
continue
chunk_index = self._chunks_index_to_be_downloaded.pop(0)
self._config.download_chunk_from_index(chunk_index)
self._downloaded_chunks += 1
class BinaryReader:
@@ -123,7 +139,7 @@ class BinaryReader:
self._rank: Optional[int] = None
self._config: Optional[ChunksConfig] = None
self._prepare_thread: Optional[PrepareChunksThread] = None
self._chunks_index_to_be_processed: List[int] = []
self._chunks_index_to_be_downloaded: List[int] = []
self._item_loader = item_loader or PyTreeLoader()
self._last_chunk_index: Optional[int] = None
self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size))
@@ -175,8 +191,8 @@ class BinaryReader:
self._prepare_thread = PrepareChunksThread(self._config, self._max_cache_size)
self._prepare_thread.start()
if index.chunk_indexes:
self._chunks_index_to_be_processed.extend(index.chunk_indexes)
self._prepare_thread.add(index.chunk_indexes)
self._chunks_index_to_be_downloaded.extend(index.chunk_indexes)
self._prepare_thread.download(index.chunk_indexes)
# If the chunk_index isn't already in the download and delete queues, add it.
if index.chunk_index != self._last_chunk_index:
@@ -186,7 +202,7 @@
self._prepare_thread.delete([self._last_chunk_index])
self._last_chunk_index = index.chunk_index
self._prepare_thread.add([index.chunk_index])
self._prepare_thread.download([index.chunk_index])
# Fetch the element
chunk_filepath, begin, _ = self.config[index]

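On the deletion side shown above, cleanup only fires after enough chunks have been processed, and it never removes a chunk that is still queued for download. The standalone helper below sketches that guard under the same `shutil.disk_usage` gate; the function name, arguments, and chunk file naming are illustrative assumptions, not the library's API.

```python
import os
import shutil
from typing import List


def cleanup_cache(cache_dir: str, max_cache_size: int,
                  to_delete: List[int], still_needed: List[int]) -> None:
    """Delete queued chunk files once the size budget is exceeded (illustrative only)."""
    # shutil.disk_usage reports (total, used, free) in bytes for the filesystem
    # hosting `cache_dir`; the diff above gates cleanup on the same call.
    if shutil.disk_usage(cache_dir).total < max_cache_size:
        return
    for chunk_index in to_delete:
        if chunk_index in still_needed:
            # Never remove a chunk the reader still expects to download or read.
            continue
        path = os.path.join(cache_dir, f"chunk-{chunk_index}.bin")  # hypothetical file naming
        if os.path.exists(path):
            os.remove(path)
```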

@@ -223,7 +223,7 @@ def test_wait_for_file_to_exist():
def test_broadcast_object(tmpdir, monkeypatch):
data_processor = DataProcessor(input_dir=tmpdir)
data_processor = DataProcessor(input_dir=str(tmpdir))
assert data_processor._broadcast_object("dummy") == "dummy"
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
monkeypatch.setattr(data_processor_module, "_distributed_is_initialized", lambda: True)
@@ -244,7 +244,7 @@ def test_cache_dir_cleanup(tmpdir, monkeypatch):
assert os.listdir(cache_dir) == ["a.txt"]
data_processor = DataProcessor(input_dir=tmpdir)
data_processor = DataProcessor(input_dir=str(tmpdir))
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", str(cache_dir))
monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", str(cache_data_dir))
data_processor._cleanup_cache()
@@ -568,7 +568,7 @@ def test_data_processsor_nlp(tmpdir, monkeypatch):
with open(os.path.join(tmpdir, "dummy.txt"), "w") as f:
f.write("Hello World !")
data_processor = DataProcessor(input_dir=tmpdir, num_workers=1, num_downloaders=1)
data_processor = DataProcessor(input_dir=str(tmpdir), num_workers=1, num_downloaders=1)
data_processor.run(TextTokenizeRecipe(chunk_size=1024 * 11))


@@ -27,20 +27,20 @@ from torch.utils.data import DataLoader
def test_streaming_dataset(tmpdir, monkeypatch):
seed_everything(42)
dataset = StreamingDataset(input_dir=tmpdir)
dataset = StreamingDataset(input_dir=str(tmpdir))
with pytest.raises(ValueError, match="The provided dataset"):
iter(dataset)
dataset = StreamingDataset(input_dir=tmpdir)
dataset = StreamingDataset(input_dir=str(tmpdir))
with pytest.raises(ValueError, match="The provided dataset"):
_ = dataset[0]
cache = Cache(tmpdir, chunk_size=10)
cache = Cache(str(tmpdir), chunk_size=10)
for i in range(12):
cache[i] = i
cache.done()
cache.merge()
dataset = StreamingDataset(input_dir=tmpdir)
dataset = StreamingDataset(input_dir=str(tmpdir))
assert len(dataset) == 12
dataset_iter = iter(dataset)
@@ -54,7 +54,7 @@ def test_streaming_dataset(tmpdir, monkeypatch):
@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
@mock.patch("lightning.data.streaming.dataset.os.makedirs")
def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir):
def test_create_cache_dir_in_lightning_cloud(makedirs_mock):
# Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call
dataset = StreamingDataset("dummy")
expected = os.path.join("/cache", "chunks", "275876e34cf609db118f3d84b799a790", "0")
@@ -67,14 +67,14 @@ def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir):
def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
seed_everything(42)
cache = Cache(tmpdir, chunk_size=10)
cache = Cache(str(tmpdir), chunk_size=10)
for i in range(101):
cache[i] = i
cache.done()
cache.merge()
dataset = StreamingDataset(input_dir=tmpdir, shuffle=False, drop_last=drop_last)
dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False, drop_last=drop_last)
assert not dataset.shuffle
_ = dataset[0] # init shuffler
assert isinstance(dataset.shuffler, NoShuffle)
@@ -98,7 +98,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
assert process_1_2[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert len(process_1_2) == 50 + int(not drop_last)
dataset = StreamingDataset(input_dir=tmpdir, shuffle=False, drop_last=drop_last)
dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False, drop_last=drop_last)
dataset.distributed_env = _DistributedEnv(2, 1)
assert len(dataset) == 50
dataset_iter = iter(dataset)
@@ -148,14 +148,14 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
seed_everything(42)
cache = Cache(input_dir=tmpdir, chunk_size=10)
cache = Cache(input_dir=str(tmpdir), chunk_size=10)
for i in range(1097):
cache[i] = i
cache.done()
cache.merge()
dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
assert dataset.shuffle
_ = dataset[0]
assert isinstance(dataset.shuffler, FullShuffle)
@@ -171,7 +171,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
assert process_1_1[:10] == [785, 788, 782, 783, 789, 787, 786, 781, 784, 780]
assert len(process_1_1) == 548
dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
iter(dataset_2)
assert isinstance(dataset_2.shuffler, FullShuffle)
dataset_2.distributed_env = _DistributedEnv(2, 1)
@@ -188,14 +188,14 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
seed_everything(42)
cache = Cache(tmpdir, chunk_size=10)
cache = Cache(str(tmpdir), chunk_size=10)
for i in range(1222):
cache[i] = i
cache.done()
cache.merge()
dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
assert dataset.shuffle
_ = dataset[0]
assert isinstance(dataset.shuffler, FullShuffle)
@@ -211,7 +211,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
assert process_1_1[:10] == [185, 184, 182, 189, 187, 181, 183, 180, 186, 188]
assert len(process_1_1) == 611
dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
iter(dataset_2)
assert isinstance(dataset_2.shuffler, FullShuffle)
dataset_2.distributed_env = _DistributedEnv(2, 1)
@@ -256,14 +256,14 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch):
def test_dataset_cache_recreation(tmpdir):
"""Test that we recreate the cache and other objects only when appropriate."""
cache = Cache(tmpdir, chunk_size=10)
cache = Cache(str(tmpdir), chunk_size=10)
for i in range(10):
cache[i] = i
cache.done()
cache.merge()
# repeated `len()` calls
dataset = StreamingDataset(input_dir=tmpdir)
dataset = StreamingDataset(input_dir=str(tmpdir))
assert not dataset.cache
assert not dataset.shuffler
len(dataset)
@@ -274,7 +274,7 @@ def test_dataset_cache_recreation(tmpdir):
assert dataset.shuffler is shuffler
# repeated `iter()` calls
dataset = StreamingDataset(input_dir=tmpdir)
dataset = StreamingDataset(input_dir=str(tmpdir))
assert not dataset.cache
assert not dataset.shuffler
iter(dataset)
@@ -289,7 +289,7 @@ def test_dataset_cache_recreation(tmpdir):
assert dataset.shuffler is not shuffler # shuffler gets recreated
# repeated `getitem()` calls
dataset = StreamingDataset(input_dir=tmpdir)
dataset = StreamingDataset(input_dir=str(tmpdir))
assert not dataset.cache
assert not dataset.shuffler
_ = dataset[0]


@@ -43,7 +43,37 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
shutil_mock.disk_usage.return_value = disk_usage
monkeypatch.setattr(reader, "shutil", shutil_mock)
expected = []
for i in range(25):
expected.append([i, len(os.listdir(cache_dir))])
assert cache[i] == i
assert expected == [
[0, 0],
[1, 1],
[2, 1],
[3, 2],
[4, 2],
[5, 3],
[6, 3],
[7, 4],
[8, 4],
[9, 5],
[10, 5],
[11, 6],
[12, 6],
[13, 7],
[14, 7],
[15, 8],
[16, 8],
[17, 9],
[18, 9],
[19, 10],
[20, 10],
[21, 11],
[22, 11], # Cleanup is triggered
[23, 2],
[24, 2],
]
assert len(os.listdir(cache_dir)) == 3


@@ -13,6 +13,7 @@
import json
import os
import sys
import numpy as np
import pytest
@@ -97,7 +98,7 @@ def test_binary_writer_with_ints_and_chunk_size(tmpdir):
assert data == {"i": i, "i+1": i + 1, "i+2": i + 2}
@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "darwin", reason="Requires: ['pil']")
def test_binary_writer_with_jpeg_and_int(tmpdir):
"""Validate the writer and reader can serialize / deserialize a pair of image and label."""
from PIL import Image
@@ -136,7 +137,7 @@ def test_binary_writer_with_jpeg_and_int(tmpdir):
assert data["y"] == i
@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "darwin", reason="Requires: ['pil']")
def test_binary_writer_with_jpeg_filepath_and_int(tmpdir):
"""Validate the writer and reader can serialize / deserialize a pair of image and label."""
from PIL import Image