Prevent downloading more chunks than needed (#18964)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
parent 20f58f63ef, commit 1b3a3fbaad
@@ -34,20 +34,25 @@ if _TORCH_GREATER_EQUAL_2_1_0:
 class PrepareChunksThread(Thread):
     """This thread is responsible to download the chunks associated to a given worker."""

-    def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None) -> None:
+    def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None, pre_download: int = 10) -> None:
         super().__init__(daemon=True)
         self._config = config
-        self._chunks_index_to_be_processed: List[int] = []
+        self._chunks_index_to_be_downloaded: List[int] = []
         self._chunks_index_to_be_deleted: List[int] = []
         self._lock = Lock()
         self._max_cache_size = max_cache_size
+        self._downloaded_chunks = 0
+        self._processed_chunks = 0
+        self._processed_chunks_counter = 0
+        self._delete_chunks = 0
+        self._pre_download = pre_download

-    def add(self, chunk_indices: List[int]) -> None:
+    def download(self, chunk_indices: List[int]) -> None:
         """Receive the list of the chunk indices to download for the current epoch."""
         with self._lock:
             for chunk_indice in chunk_indices:
-                if chunk_indice not in self._chunks_index_to_be_processed:
-                    self._chunks_index_to_be_processed.append(chunk_indice)
+                if chunk_indice not in self._chunks_index_to_be_downloaded:
+                    self._chunks_index_to_be_downloaded.append(chunk_indice)

     def delete(self, chunk_indices: List[int]) -> None:
         """Receive the list of the chunk indices to download for the current epoch."""
@@ -55,6 +60,8 @@ class PrepareChunksThread(Thread):
             for chunk_indice in chunk_indices:
                 if chunk_indice not in self._chunks_index_to_be_deleted:
                     self._chunks_index_to_be_deleted.append(chunk_indice)
+                    self._processed_chunks += 1
+                    self._processed_chunks_counter += 1

     def _delete(self, chunk_index: int) -> None:
         chunk_filepath, begin, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
@@ -65,25 +72,34 @@ class PrepareChunksThread(Thread):
     def run(self) -> None:
         while True:
             with self._lock:
-                if len(self._chunks_index_to_be_processed) == 0 and len(self._chunks_index_to_be_deleted) == 0:
-                    sleep(0.005)
+                # Wait for something to do
+                if len(self._chunks_index_to_be_downloaded) == 0 and len(self._chunks_index_to_be_deleted) == 0:
+                    sleep(0.01)
                     continue

                 # Delete the chunks if we are missing disk space.
-                # Check every 10 items to avoid losing too much time
-                if self._max_cache_size and len(self._chunks_index_to_be_deleted) > 10:
+                if self._max_cache_size and self._processed_chunks_counter >= self._pre_download:
                     if shutil.disk_usage(self._config._cache_dir).total >= self._max_cache_size:
                         for chunk_index in self._chunks_index_to_be_deleted:
-                            if chunk_index not in self._chunks_index_to_be_processed:
+                            if chunk_index not in self._chunks_index_to_be_downloaded:
                                 self._delete(chunk_index)
+                                self._delete_chunks += 1
+                                self._processed_chunks_counter = 0
                         self._chunks_index_to_be_deleted = []

-                if len(self._chunks_index_to_be_processed) == 0:
+                # If there is no chunks to download, go back to waiting
+                if len(self._chunks_index_to_be_downloaded) == 0:
                     continue

-                chunk_index = self._chunks_index_to_be_processed.pop(0)
+                # If we have already downloaded too many chunks, let's wait for processed chunks to catch up
+                if self._max_cache_size and (self._downloaded_chunks - self._processed_chunks) > self._pre_download:
+                    sleep(0.01)
+                    continue
+
+                chunk_index = self._chunks_index_to_be_downloaded.pop(0)

             self._config.download_chunk_from_index(chunk_index)
+            self._downloaded_chunks += 1


 class BinaryReader:
@@ -123,7 +139,7 @@ class BinaryReader:
         self._rank: Optional[int] = None
         self._config: Optional[ChunksConfig] = None
         self._prepare_thread: Optional[PrepareChunksThread] = None
-        self._chunks_index_to_be_processed: List[int] = []
+        self._chunks_index_to_be_downloaded: List[int] = []
         self._item_loader = item_loader or PyTreeLoader()
         self._last_chunk_index: Optional[int] = None
         self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size))
@@ -175,8 +191,8 @@ class BinaryReader:
             self._prepare_thread = PrepareChunksThread(self._config, self._max_cache_size)
             self._prepare_thread.start()
             if index.chunk_indexes:
-                self._chunks_index_to_be_processed.extend(index.chunk_indexes)
-                self._prepare_thread.add(index.chunk_indexes)
+                self._chunks_index_to_be_downloaded.extend(index.chunk_indexes)
+                self._prepare_thread.download(index.chunk_indexes)

         # If the chunk_index isn't already in the download and delete queues, add it.
         if index.chunk_index != self._last_chunk_index:
@@ -186,7 +202,7 @@ class BinaryReader:
                 self._prepare_thread.delete([self._last_chunk_index])

             self._last_chunk_index = index.chunk_index
-            self._prepare_thread.add([index.chunk_index])
+            self._prepare_thread.download([index.chunk_index])

         # Fetch the element
         chunk_filepath, begin, _ = self.config[index]
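The behavioural core of the change is the guard in `run()`: the thread stops fetching once it is more than `pre_download` chunks ahead of what the reader has released (via `delete()`), instead of eagerly downloading every queued chunk. Below is a minimal, self-contained sketch of that bounded look-ahead pattern; the class and method names are hypothetical, and a plain callable stands in for `ChunksConfig.download_chunk_from_index`.

import time
from threading import Lock, Thread
from typing import Callable, List


class BoundedPrefetcher(Thread):
    """Simplified illustration only: fetch at most `pre_download` chunks ahead of the consumer."""

    def __init__(self, download_fn: Callable[[int], None], pre_download: int = 10) -> None:
        super().__init__(daemon=True)
        self._download_fn = download_fn
        self._queue: List[int] = []      # chunk indices waiting to be fetched
        self._lock = Lock()
        self._downloaded = 0             # chunks fetched so far
        self._processed = 0              # chunks the consumer has finished with
        self._pre_download = pre_download

    def download(self, chunk_indices: List[int]) -> None:
        with self._lock:
            for index in chunk_indices:
                if index not in self._queue:
                    self._queue.append(index)

    def mark_processed(self) -> None:
        # The real thread infers progress from delete() calls; here it is explicit.
        with self._lock:
            self._processed += 1

    def run(self) -> None:
        while True:
            with self._lock:
                too_far_ahead = (self._downloaded - self._processed) > self._pre_download
                chunk_index = self._queue.pop(0) if self._queue and not too_far_ahead else None
            if chunk_index is None:
                time.sleep(0.01)  # nothing queued, or waiting for the consumer to catch up
                continue
            self._download_fn(chunk_index)
            with self._lock:
                self._downloaded += 1

Because the loop only pops a chunk while `downloaded - processed <= pre_download`, the on-disk working set stays bounded at roughly `pre_download` chunks no matter how many indices were queued for the epoch.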
@@ -223,7 +223,7 @@ def test_wait_for_file_to_exist():


 def test_broadcast_object(tmpdir, monkeypatch):
-    data_processor = DataProcessor(input_dir=tmpdir)
+    data_processor = DataProcessor(input_dir=str(tmpdir))
     assert data_processor._broadcast_object("dummy") == "dummy"
     monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
     monkeypatch.setattr(data_processor_module, "_distributed_is_initialized", lambda: True)
@@ -244,7 +244,7 @@ def test_cache_dir_cleanup(tmpdir, monkeypatch):

     assert os.listdir(cache_dir) == ["a.txt"]

-    data_processor = DataProcessor(input_dir=tmpdir)
+    data_processor = DataProcessor(input_dir=str(tmpdir))
     monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", str(cache_dir))
     monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", str(cache_data_dir))
     data_processor._cleanup_cache()
@@ -568,7 +568,7 @@ def test_data_processsor_nlp(tmpdir, monkeypatch):
     with open(os.path.join(tmpdir, "dummy.txt"), "w") as f:
         f.write("Hello World !")

-    data_processor = DataProcessor(input_dir=tmpdir, num_workers=1, num_downloaders=1)
+    data_processor = DataProcessor(input_dir=str(tmpdir), num_workers=1, num_downloaders=1)
     data_processor.run(TextTokenizeRecipe(chunk_size=1024 * 11))

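The test churn in this file, and in the dataset tests below, is one repeated one-line fix: pytest's `tmpdir` fixture yields a `py.path.local` object rather than a plain string, so the tests now pass `str(tmpdir)` explicitly, presumably because the constructors under test expect a string path. A tiny pytest sketch of the distinction (hypothetical test name, no Lightning imports):

import os


def test_tmpdir_is_not_a_plain_string(tmpdir):
    # pytest's `tmpdir` is a py.path.local object, not a str subclass.
    assert not isinstance(tmpdir, str)
    # str(tmpdir) yields the plain filesystem path that string-based APIs expect.
    path = str(tmpdir)
    assert isinstance(path, str)
    assert os.path.isdir(path)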
@@ -27,20 +27,20 @@ from torch.utils.data import DataLoader
 def test_streaming_dataset(tmpdir, monkeypatch):
     seed_everything(42)

-    dataset = StreamingDataset(input_dir=tmpdir)
+    dataset = StreamingDataset(input_dir=str(tmpdir))
     with pytest.raises(ValueError, match="The provided dataset"):
         iter(dataset)
-    dataset = StreamingDataset(input_dir=tmpdir)
+    dataset = StreamingDataset(input_dir=str(tmpdir))
     with pytest.raises(ValueError, match="The provided dataset"):
         _ = dataset[0]

-    cache = Cache(tmpdir, chunk_size=10)
+    cache = Cache(str(tmpdir), chunk_size=10)
     for i in range(12):
         cache[i] = i
     cache.done()
     cache.merge()

-    dataset = StreamingDataset(input_dir=tmpdir)
+    dataset = StreamingDataset(input_dir=str(tmpdir))

     assert len(dataset) == 12
     dataset_iter = iter(dataset)
@@ -54,7 +54,7 @@ def test_streaming_dataset(tmpdir, monkeypatch):

 @mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
 @mock.patch("lightning.data.streaming.dataset.os.makedirs")
-def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir):
+def test_create_cache_dir_in_lightning_cloud(makedirs_mock):
     # Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call
     dataset = StreamingDataset("dummy")
     expected = os.path.join("/cache", "chunks", "275876e34cf609db118f3d84b799a790", "0")
@@ -67,14 +67,14 @@ def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir):
 def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
     seed_everything(42)

-    cache = Cache(tmpdir, chunk_size=10)
+    cache = Cache(str(tmpdir), chunk_size=10)
     for i in range(101):
         cache[i] = i

     cache.done()
     cache.merge()

-    dataset = StreamingDataset(input_dir=tmpdir, shuffle=False, drop_last=drop_last)
+    dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False, drop_last=drop_last)
     assert not dataset.shuffle
     _ = dataset[0]  # init shuffler
     assert isinstance(dataset.shuffler, NoShuffle)
@@ -98,7 +98,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
     assert process_1_2[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
     assert len(process_1_2) == 50 + int(not drop_last)

-    dataset = StreamingDataset(input_dir=tmpdir, shuffle=False, drop_last=drop_last)
+    dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False, drop_last=drop_last)
     dataset.distributed_env = _DistributedEnv(2, 1)
     assert len(dataset) == 50
     dataset_iter = iter(dataset)
@@ -148,14 +148,14 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
 def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
     seed_everything(42)

-    cache = Cache(input_dir=tmpdir, chunk_size=10)
+    cache = Cache(input_dir=str(tmpdir), chunk_size=10)
     for i in range(1097):
         cache[i] = i

     cache.done()
     cache.merge()

-    dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
+    dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
     assert dataset.shuffle
     _ = dataset[0]
     assert isinstance(dataset.shuffler, FullShuffle)
@@ -171,7 +171,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
     assert process_1_1[:10] == [785, 788, 782, 783, 789, 787, 786, 781, 784, 780]
     assert len(process_1_1) == 548

-    dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
+    dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
     iter(dataset_2)
     assert isinstance(dataset_2.shuffler, FullShuffle)
     dataset_2.distributed_env = _DistributedEnv(2, 1)
@@ -188,14 +188,14 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
 def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
     seed_everything(42)

-    cache = Cache(tmpdir, chunk_size=10)
+    cache = Cache(str(tmpdir), chunk_size=10)
     for i in range(1222):
         cache[i] = i

     cache.done()
     cache.merge()

-    dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
+    dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
     assert dataset.shuffle
     _ = dataset[0]
     assert isinstance(dataset.shuffler, FullShuffle)
@@ -211,7 +211,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
     assert process_1_1[:10] == [185, 184, 182, 189, 187, 181, 183, 180, 186, 188]
     assert len(process_1_1) == 611

-    dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
+    dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
     iter(dataset_2)
     assert isinstance(dataset_2.shuffler, FullShuffle)
     dataset_2.distributed_env = _DistributedEnv(2, 1)
@@ -256,14 +256,14 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch):

 def test_dataset_cache_recreation(tmpdir):
     """Test that we recreate the cache and other objects only when appropriate."""
-    cache = Cache(tmpdir, chunk_size=10)
+    cache = Cache(str(tmpdir), chunk_size=10)
     for i in range(10):
         cache[i] = i
     cache.done()
     cache.merge()

     # repated `len()` calls
-    dataset = StreamingDataset(input_dir=tmpdir)
+    dataset = StreamingDataset(input_dir=str(tmpdir))
     assert not dataset.cache
     assert not dataset.shuffler
     len(dataset)
@@ -274,7 +274,7 @@ def test_dataset_cache_recreation(tmpdir):
     assert dataset.shuffler is shuffler

     # repeated `iter()` calls
-    dataset = StreamingDataset(input_dir=tmpdir)
+    dataset = StreamingDataset(input_dir=str(tmpdir))
     assert not dataset.cache
     assert not dataset.shuffler
     iter(dataset)
@@ -289,7 +289,7 @@ def test_dataset_cache_recreation(tmpdir):
     assert dataset.shuffler is not shuffler  # shuffler gets recreated

     # repeated `getitem()` calls
-    dataset = StreamingDataset(input_dir=tmpdir)
+    dataset = StreamingDataset(input_dir=str(tmpdir))
     assert not dataset.cache
     assert not dataset.shuffler
     _ = dataset[0]
@@ -43,7 +43,37 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
     shutil_mock.disk_usage.return_value = disk_usage
     monkeypatch.setattr(reader, "shutil", shutil_mock)

+    expected = []
     for i in range(25):
+        expected.append([i, len(os.listdir(cache_dir))])
         assert cache[i] == i

+    assert expected == [
+        [0, 0],
+        [1, 1],
+        [2, 1],
+        [3, 2],
+        [4, 2],
+        [5, 3],
+        [6, 3],
+        [7, 4],
+        [8, 4],
+        [9, 5],
+        [10, 5],
+        [11, 6],
+        [12, 6],
+        [13, 7],
+        [14, 7],
+        [15, 8],
+        [16, 8],
+        [17, 9],
+        [18, 9],
+        [19, 10],
+        [20, 10],
+        [21, 11],
+        [22, 11],  # Cleanup is triggered
+        [23, 2],
+        [24, 2],
+    ]
+
     assert len(os.listdir(cache_dir)) == 3
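The test above drives the eviction path by replacing the `shutil` module used by the reader with a mock whose `disk_usage()` always reports a size above `max_cache_size`. Here is a hedged, standalone sketch of that mocking pattern; `is_cache_full` is a stand-in for the check inside `PrepareChunksThread.run()`, not the real function.

from unittest import mock


def is_cache_full(shutil_module, cache_dir: str, max_cache_size: int) -> bool:
    # Stand-in for the disk-usage check the download thread performs before evicting chunks.
    return shutil_module.disk_usage(cache_dir).total >= max_cache_size


def test_is_cache_full_with_mocked_disk_usage(tmpdir):
    shutil_mock = mock.Mock()
    # shutil.disk_usage() returns a named tuple (total, used, free); a Mock exposing
    # a `total` attribute is enough for the check above.
    shutil_mock.disk_usage.return_value = mock.Mock(total=10 * 1024**3)

    assert is_cache_full(shutil_mock, str(tmpdir), max_cache_size=1024**3)
    shutil_mock.disk_usage.assert_called_once_with(str(tmpdir))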
@@ -13,6 +13,7 @@

 import json
 import os
+import sys

 import numpy as np
 import pytest
@@ -97,7 +98,7 @@ def test_binary_writer_with_ints_and_chunk_size(tmpdir):
         assert data == {"i": i, "i+1": i + 1, "i+2": i + 2}


-@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
+@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "darwin", reason="Requires: ['pil']")
 def test_binary_writer_with_jpeg_and_int(tmpdir):
     """Validate the writer and reader can serialize / deserialize a pair of image and label."""
     from PIL import Image
@@ -136,7 +137,7 @@ def test_binary_writer_with_jpeg_and_int(tmpdir):
     assert data["y"] == i


-@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
+@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "darwin", reason="Requires: ['pil']")
 def test_binary_writer_with_jpeg_filepath_and_int(tmpdir):
     """Validate the writer and reader can serialize / deserialize a pair of image and label."""
     from PIL import Image
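The writer-test change is narrower: the existing PIL guard is extended so the JPEG round-trip tests are also skipped on macOS, which is why `import sys` is added. The same conditional-skip pattern in isolation, with a placeholder availability flag and test body:

import sys

import pytest

_PIL_AVAILABLE = True  # placeholder for the real availability check


@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "darwin", reason="Requires: ['pil']")
def test_jpeg_roundtrip_placeholder():
    # Skipped when Pillow is unavailable or when running on macOS; trivially passes elsewhere.
    assert True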