From 1b3a3fbaadba417ac4f33143cc6fdc5e3fc9d205 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 7 Nov 2023 19:40:21 +0000 Subject: [PATCH] Prevent downloading more chunks than needed (#18964) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: thomas --- src/lightning/data/streaming/reader.py | 48 ++++++++++++------- .../streaming/test_data_processor.py | 6 +-- tests/tests_data/streaming/test_dataset.py | 36 +++++++------- tests/tests_data/streaming/test_reader.py | 30 ++++++++++++ tests/tests_data/streaming/test_writer.py | 5 +- 5 files changed, 86 insertions(+), 39 deletions(-) diff --git a/src/lightning/data/streaming/reader.py b/src/lightning/data/streaming/reader.py index 9ab6ea57da..7cb5971b31 100644 --- a/src/lightning/data/streaming/reader.py +++ b/src/lightning/data/streaming/reader.py @@ -34,20 +34,25 @@ if _TORCH_GREATER_EQUAL_2_1_0: class PrepareChunksThread(Thread): """This thread is responsible to download the chunks associated to a given worker.""" - def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None) -> None: + def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None, pre_download: int = 10) -> None: super().__init__(daemon=True) self._config = config - self._chunks_index_to_be_processed: List[int] = [] + self._chunks_index_to_be_downloaded: List[int] = [] self._chunks_index_to_be_deleted: List[int] = [] self._lock = Lock() self._max_cache_size = max_cache_size + self._downloaded_chunks = 0 + self._processed_chunks = 0 + self._processed_chunks_counter = 0 + self._delete_chunks = 0 + self._pre_download = pre_download - def add(self, chunk_indices: List[int]) -> None: + def download(self, chunk_indices: List[int]) -> None: """Receive the list of the chunk indices to download for the current epoch.""" with self._lock: for chunk_indice in chunk_indices: - if chunk_indice not in self._chunks_index_to_be_processed: - 
self._chunks_index_to_be_processed.append(chunk_indice) + if chunk_indice not in self._chunks_index_to_be_downloaded: + self._chunks_index_to_be_downloaded.append(chunk_indice) def delete(self, chunk_indices: List[int]) -> None: """Receive the list of the chunk indices to download for the current epoch.""" @@ -55,6 +60,8 @@ class PrepareChunksThread(Thread): for chunk_indice in chunk_indices: if chunk_indice not in self._chunks_index_to_be_deleted: self._chunks_index_to_be_deleted.append(chunk_indice) + self._processed_chunks += 1 + self._processed_chunks_counter += 1 def _delete(self, chunk_index: int) -> None: chunk_filepath, begin, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)] @@ -65,25 +72,34 @@ class PrepareChunksThread(Thread): def run(self) -> None: while True: with self._lock: - if len(self._chunks_index_to_be_processed) == 0 and len(self._chunks_index_to_be_deleted) == 0: - sleep(0.005) + # Wait for something to do + if len(self._chunks_index_to_be_downloaded) == 0 and len(self._chunks_index_to_be_deleted) == 0: + sleep(0.01) continue # Delete the chunks if we are missing disk space. 
- # Check every 10 items to avoid losing too much time - if self._max_cache_size and len(self._chunks_index_to_be_deleted) > 10: + if self._max_cache_size and self._processed_chunks_counter >= self._pre_download: if shutil.disk_usage(self._config._cache_dir).total >= self._max_cache_size: for chunk_index in self._chunks_index_to_be_deleted: - if chunk_index not in self._chunks_index_to_be_processed: + if chunk_index not in self._chunks_index_to_be_downloaded: self._delete(chunk_index) + self._delete_chunks += 1 + self._processed_chunks_counter = 0 self._chunks_index_to_be_deleted = [] - if len(self._chunks_index_to_be_processed) == 0: + # If there are no chunks to download, go back to waiting + if len(self._chunks_index_to_be_downloaded) == 0: continue - chunk_index = self._chunks_index_to_be_processed.pop(0) + # If we have already downloaded too many chunks, let's wait for processed chunks to catch up + if self._max_cache_size and (self._downloaded_chunks - self._processed_chunks) > self._pre_download: + sleep(0.01) + continue + + chunk_index = self._chunks_index_to_be_downloaded.pop(0) self._config.download_chunk_from_index(chunk_index) + self._downloaded_chunks += 1 class BinaryReader: @@ -123,7 +139,7 @@ class BinaryReader: self._rank: Optional[int] = None self._config: Optional[ChunksConfig] = None self._prepare_thread: Optional[PrepareChunksThread] = None - self._chunks_index_to_be_processed: List[int] = [] + self._chunks_index_to_be_downloaded: List[int] = [] self._item_loader = item_loader or PyTreeLoader() self._last_chunk_index: Optional[int] = None self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size)) @@ -175,8 +191,8 @@ class BinaryReader: self._prepare_thread = PrepareChunksThread(self._config, self._max_cache_size) self._prepare_thread.start() if index.chunk_indexes: - self._chunks_index_to_be_processed.extend(index.chunk_indexes) - self._prepare_thread.add(index.chunk_indexes) + 
self._chunks_index_to_be_downloaded.extend(index.chunk_indexes) + self._prepare_thread.download(index.chunk_indexes) # If the chunk_index isn't already in the download and delete queues, add it. if index.chunk_index != self._last_chunk_index: @@ -186,7 +202,7 @@ class BinaryReader: self._prepare_thread.delete([self._last_chunk_index]) self._last_chunk_index = index.chunk_index - self._prepare_thread.add([index.chunk_index]) + self._prepare_thread.download([index.chunk_index]) # Fetch the element chunk_filepath, begin, _ = self.config[index] diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py index e5ca10e520..df0bfe0116 100644 --- a/tests/tests_data/streaming/test_data_processor.py +++ b/tests/tests_data/streaming/test_data_processor.py @@ -223,7 +223,7 @@ def test_wait_for_file_to_exist(): def test_broadcast_object(tmpdir, monkeypatch): - data_processor = DataProcessor(input_dir=tmpdir) + data_processor = DataProcessor(input_dir=str(tmpdir)) assert data_processor._broadcast_object("dummy") == "dummy" monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2") monkeypatch.setattr(data_processor_module, "_distributed_is_initialized", lambda: True) @@ -244,7 +244,7 @@ def test_cache_dir_cleanup(tmpdir, monkeypatch): assert os.listdir(cache_dir) == ["a.txt"] - data_processor = DataProcessor(input_dir=tmpdir) + data_processor = DataProcessor(input_dir=str(tmpdir)) monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", str(cache_dir)) monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", str(cache_data_dir)) data_processor._cleanup_cache() @@ -568,7 +568,7 @@ def test_data_processsor_nlp(tmpdir, monkeypatch): with open(os.path.join(tmpdir, "dummy.txt"), "w") as f: f.write("Hello World !") - data_processor = DataProcessor(input_dir=tmpdir, num_workers=1, num_downloaders=1) + data_processor = DataProcessor(input_dir=str(tmpdir), num_workers=1, num_downloaders=1) data_processor.run(TextTokenizeRecipe(chunk_size=1024 * 11)) 
diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index f903f8b377..d33f05591d 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -27,20 +27,20 @@ from torch.utils.data import DataLoader def test_streaming_dataset(tmpdir, monkeypatch): seed_everything(42) - dataset = StreamingDataset(input_dir=tmpdir) + dataset = StreamingDataset(input_dir=str(tmpdir)) with pytest.raises(ValueError, match="The provided dataset"): iter(dataset) - dataset = StreamingDataset(input_dir=tmpdir) + dataset = StreamingDataset(input_dir=str(tmpdir)) with pytest.raises(ValueError, match="The provided dataset"): _ = dataset[0] - cache = Cache(tmpdir, chunk_size=10) + cache = Cache(str(tmpdir), chunk_size=10) for i in range(12): cache[i] = i cache.done() cache.merge() - dataset = StreamingDataset(input_dir=tmpdir) + dataset = StreamingDataset(input_dir=str(tmpdir)) assert len(dataset) == 12 dataset_iter = iter(dataset) @@ -54,7 +54,7 @@ def test_streaming_dataset(tmpdir, monkeypatch): @mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"}) @mock.patch("lightning.data.streaming.dataset.os.makedirs") -def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir): +def test_create_cache_dir_in_lightning_cloud(makedirs_mock): # Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call dataset = StreamingDataset("dummy") expected = os.path.join("/cache", "chunks", "275876e34cf609db118f3d84b799a790", "0") @@ -67,14 +67,14 @@ def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir): def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir): seed_everything(42) - cache = Cache(tmpdir, chunk_size=10) + cache = Cache(str(tmpdir), chunk_size=10) for i in range(101): cache[i] = i cache.done() cache.merge() - dataset = StreamingDataset(input_dir=tmpdir, shuffle=False, 
drop_last=drop_last) + dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False, drop_last=drop_last) assert not dataset.shuffle _ = dataset[0] # init shuffler assert isinstance(dataset.shuffler, NoShuffle) @@ -98,7 +98,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir): assert process_1_2[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] assert len(process_1_2) == 50 + int(not drop_last) - dataset = StreamingDataset(input_dir=tmpdir, shuffle=False, drop_last=drop_last) + dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False, drop_last=drop_last) dataset.distributed_env = _DistributedEnv(2, 1) assert len(dataset) == 50 dataset_iter = iter(dataset) @@ -148,14 +148,14 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir): def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir): seed_everything(42) - cache = Cache(input_dir=tmpdir, chunk_size=10) + cache = Cache(input_dir=str(tmpdir), chunk_size=10) for i in range(1097): cache[i] = i cache.done() cache.merge() - dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last) + dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last) assert dataset.shuffle _ = dataset[0] assert isinstance(dataset.shuffler, FullShuffle) @@ -171,7 +171,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir): assert process_1_1[:10] == [785, 788, 782, 783, 789, 787, 786, 781, 784, 780] assert len(process_1_1) == 548 - dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last) + dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last) iter(dataset_2) assert isinstance(dataset_2.shuffler, FullShuffle) dataset_2.distributed_env = _DistributedEnv(2, 1) @@ -188,14 +188,14 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir): def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir): seed_everything(42) - cache = 
Cache(tmpdir, chunk_size=10) + cache = Cache(str(tmpdir), chunk_size=10) for i in range(1222): cache[i] = i cache.done() cache.merge() - dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last) + dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last) assert dataset.shuffle _ = dataset[0] assert isinstance(dataset.shuffler, FullShuffle) @@ -211,7 +211,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir): assert process_1_1[:10] == [185, 184, 182, 189, 187, 181, 183, 180, 186, 188] assert len(process_1_1) == 611 - dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last) + dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last) iter(dataset_2) assert isinstance(dataset_2.shuffler, FullShuffle) dataset_2.distributed_env = _DistributedEnv(2, 1) @@ -256,14 +256,14 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch): def test_dataset_cache_recreation(tmpdir): """Test that we recreate the cache and other objects only when appropriate.""" - cache = Cache(tmpdir, chunk_size=10) + cache = Cache(str(tmpdir), chunk_size=10) for i in range(10): cache[i] = i cache.done() cache.merge() # repated `len()` calls - dataset = StreamingDataset(input_dir=tmpdir) + dataset = StreamingDataset(input_dir=str(tmpdir)) assert not dataset.cache assert not dataset.shuffler len(dataset) @@ -274,7 +274,7 @@ def test_dataset_cache_recreation(tmpdir): assert dataset.shuffler is shuffler # repeated `iter()` calls - dataset = StreamingDataset(input_dir=tmpdir) + dataset = StreamingDataset(input_dir=str(tmpdir)) assert not dataset.cache assert not dataset.shuffler iter(dataset) @@ -289,7 +289,7 @@ def test_dataset_cache_recreation(tmpdir): assert dataset.shuffler is not shuffler # shuffler gets recreated # repeated `getitem()` calls - dataset = StreamingDataset(input_dir=tmpdir) + dataset = StreamingDataset(input_dir=str(tmpdir)) assert not dataset.cache 
assert not dataset.shuffler _ = dataset[0] diff --git a/tests/tests_data/streaming/test_reader.py b/tests/tests_data/streaming/test_reader.py index 19e9105f5f..93ddebf873 100644 --- a/tests/tests_data/streaming/test_reader.py +++ b/tests/tests_data/streaming/test_reader.py @@ -43,7 +43,37 @@ def test_reader_chunk_removal(tmpdir, monkeypatch): shutil_mock.disk_usage.return_value = disk_usage monkeypatch.setattr(reader, "shutil", shutil_mock) + expected = [] for i in range(25): + expected.append([i, len(os.listdir(cache_dir))]) assert cache[i] == i + assert expected == [ + [0, 0], + [1, 1], + [2, 1], + [3, 2], + [4, 2], + [5, 3], + [6, 3], + [7, 4], + [8, 4], + [9, 5], + [10, 5], + [11, 6], + [12, 6], + [13, 7], + [14, 7], + [15, 8], + [16, 8], + [17, 9], + [18, 9], + [19, 10], + [20, 10], + [21, 11], + [22, 11], # Cleanup is triggered + [23, 2], + [24, 2], + ] + assert len(os.listdir(cache_dir)) == 3 diff --git a/tests/tests_data/streaming/test_writer.py b/tests/tests_data/streaming/test_writer.py index 80a3e9fd38..2128a73b35 100644 --- a/tests/tests_data/streaming/test_writer.py +++ b/tests/tests_data/streaming/test_writer.py @@ -13,6 +13,7 @@ import json import os +import sys import numpy as np import pytest @@ -97,7 +98,7 @@ def test_binary_writer_with_ints_and_chunk_size(tmpdir): assert data == {"i": i, "i+1": i + 1, "i+2": i + 2} -@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "darwin", reason="Requires: ['pil']") def test_binary_writer_with_jpeg_and_int(tmpdir): """Validate the writer and reader can serialize / deserialize a pair of image and label.""" from PIL import Image @@ -136,7 +137,7 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): assert data["y"] == i -@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "darwin", reason="Requires: ['pil']") def 
test_binary_writer_with_jpeg_filepath_and_int(tmpdir): """Validate the writer and reader can serialize / deserialize a pair of image and label.""" from PIL import Image