From 6e517bd55b50166138ce6ab915abd4547702994b Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 16 Nov 2023 18:06:58 -0500 Subject: [PATCH] Resolve Item Loader bugs (#19017) Co-authored-by: thomas Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/lightning/data/streaming/dataset.py | 20 +- src/lightning/data/streaming/functions.py | 10 +- src/lightning/data/streaming/item_loader.py | 13 +- src/lightning/data/streaming/reader.py | 34 +++- src/lightning/data/streaming/sampler.py | 1 + src/lightning/data/streaming/shuffle.py | 12 +- tests/tests_data/streaming/test_cache.py | 35 ++++ tests/tests_data/streaming/test_dataset.py | 210 +++++++++++++++++++- tests/tests_data/streaming/test_reader.py | 17 +- 9 files changed, 319 insertions(+), 33 deletions(-) diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index fe37a366d1..495e36b5da 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import hashlib import os from typing import Any, Dict, List, Optional, Union @@ -84,12 +83,16 @@ class StreamingDataset(IterableDataset): def _create_cache(self, worker_env: _WorkerEnv) -> Cache: env = Environment(dist_env=self.distributed_env, worker_env=worker_env) - cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank) - cache_dir = copy.deepcopy(self.input_dir) - if cache_path: - cache_dir.path = cache_path - cache = Cache(input_dir=cache_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers) + # TODO: Move this to lightning-cloud + if "this_" not in self.input_dir.path: + cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank) + if cache_path is not None: + self.input_dir.path = cache_path + + cache = Cache( + input_dir=self.input_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers + ) cache._reader._try_load_config() if not cache.filled: @@ -136,6 +139,7 @@ class StreamingDataset(IterableDataset): self.current_indexes = [] self.chunk_index = 0 self.index = 0 + self.has_triggered_download = False return self @@ -167,6 +171,8 @@ class StreamingDataset(IterableDataset): self.current_indexes = self.shuffler(current_indexes) self.chunk_index += 1 + last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1 + # Get the first index index = self.current_indexes.pop(0) @@ -175,7 +181,9 @@ class StreamingDataset(IterableDataset): ChunkedIndex( index=index, chunk_index=self.worker_chunks[self.chunk_index - 1], + # We provide the chunks indexes only one the first chunk_indexes=None if self.has_triggered_download else self.worker_chunks, + last_index=last_index, ) ) diff --git a/src/lightning/data/streaming/functions.py b/src/lightning/data/streaming/functions.py index 3d4e87eaa0..c8683579c9 100644 --- a/src/lightning/data/streaming/functions.py +++ b/src/lightning/data/streaming/functions.py @@ -42,7 +42,7 @@ def _get_input_dir(inputs: Sequence[Any]) -> str: if len(indexed_paths) == 0: raise ValueError(f"The provided item {inputs[0]} didn't contain any filepaths.") - absolute_path = str(Path(indexed_paths[0]).resolve()) + absolute_path = str(Path(list(indexed_paths.values())[0]).resolve()) if indexed_paths[0] != absolute_path: raise ValueError("The provided path should be absolute.") @@ 
-128,6 +128,7 @@ def map( num_nodes: Optional[int] = None, machine: Optional[str] = None, num_downloaders: Optional[int] = None, + reorder_files: bool = True, ) -> None: """This function map a callbable over a collection of files possibly in a distributed way. @@ -141,6 +142,8 @@ def map( num_nodes: When doing remote execution, the number of nodes to use. machine: When doing remote execution, the machine to use. num_downloaders: The number of downloaders per worker. + reorder_files: By default, reorders the files by file size to distribute work equally among all workers. + Set this to ``False`` if the order in which samples are processed should be preserved. """ if not isinstance(inputs, Sequence): @@ -168,6 +171,7 @@ def map( num_workers=num_workers or os.cpu_count(), fast_dev_run=fast_dev_run, num_downloaders=num_downloaders, + reorder_files=reorder_files, ) return data_processor.run(LambdaDataTransformRecipe(fn, inputs)) return _execute( @@ -189,6 +193,7 @@ def optimize( num_nodes: Optional[int] = None, machine: Optional[str] = None, num_downloaders: Optional[int] = None, + reorder_files: bool = True, ) -> None: """This function converts a dataset into chunks possibly in a distributed way. @@ -205,6 +210,8 @@ def optimize( num_nodes: When doing remote execution, the number of nodes to use. machine: When doing remote execution, the machine to use. num_downloaders: The number of downloaders per worker. + reorder_files: By default, reorders the files by file size to distribute work equally among all workers. + Set this to ``False`` if the order in which samples are processed should be preserved. """ if not isinstance(inputs, Sequence): @@ -235,6 +242,7 @@ def optimize( num_workers=num_workers or os.cpu_count(), fast_dev_run=fast_dev_run, num_downloaders=num_downloaders, + reorder_files=reorder_files, ) return data_processor.run( LambdaDataChunkRecipe( diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py index 8f8cd2dbce..1b028369ab 100644 --- a/src/lightning/data/streaming/item_loader.py +++ b/src/lightning/data/streaming/item_loader.py @@ -110,7 +110,6 @@ class TokensLoader(BaseItemLoader): super().__init__() self._block_size = block_size - self._intervals: List[Tuple[int, int]] = [] self._mmaps: Dict[int, np.memmap] = {} self._buffers: Dict[int, bytes] = {} self._dtype: Optional[torch.dtype] = None @@ -123,16 +122,16 @@ class TokensLoader(BaseItemLoader): raise ValueError("The provided chunks isn't properly setup.") def generate_intervals(self) -> List[Tuple[int, int]]: + intervals = [] begin = 0 end = 0 for chunk in self._chunks: dim = chunk["dim"] num_blocks = dim // self._block_size end += num_blocks - self._intervals.append((begin, end)) + intervals.append((begin, end)) begin += num_blocks - - return self._intervals + return intervals def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor: if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath): @@ -149,6 +148,10 @@ class TokensLoader(BaseItemLoader): if chunk_index not in self._mmaps: # TODO: Add deletion and memmap close chunk = self._chunks[chunk_index] + + # Skip the header + # The number of items + the number of offsets (number of items in the chunk + 1) + # multiplied by the header encoding dtype (np.uint32) offset = (1 + chunk["chunk_size"] + 1) * 4 mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=offset) self._mmaps[chunk_index] = mmap @@ -157,5 +160,5 @@ class 
TokensLoader(BaseItemLoader): assert self._dtype buffer: bytes = self._buffers[chunk_index] - offset = self._dtype.itemsize * ((index - begin) if index >= begin else index + 1) + offset = self._dtype.itemsize * (index - begin) * self._block_size return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) diff --git a/src/lightning/data/streaming/reader.py b/src/lightning/data/streaming/reader.py index 1912305e43..7a3b84ad70 100644 --- a/src/lightning/data/streaming/reader.py +++ b/src/lightning/data/streaming/reader.py @@ -46,6 +46,7 @@ class PrepareChunksThread(Thread): self._processed_chunks_counter = 0 self._delete_chunks = 0 self._pre_download = pre_download + self._should_stop = False def download(self, chunk_indices: List[int]) -> None: """Receive the list of the chunk indices to download for the current epoch.""" @@ -69,12 +70,28 @@ class PrepareChunksThread(Thread): if os.path.exists(chunk_filepath): os.remove(chunk_filepath) + def stop(self) -> None: + """Receive the list of the chunk indices to download for the current epoch.""" + with self._lock: + self._should_stop = True + def run(self) -> None: while True: with self._lock: + if self._should_stop: + if ( + self._max_cache_size + and self._max_cache_size <= shutil.disk_usage(self._config._cache_dir).total + ): + for chunk_index in self._chunks_index_to_be_deleted: + if chunk_index not in self._chunks_index_to_be_downloaded: + self._delete(chunk_index) + self._delete_chunks += 1 + self._processed_chunks_counter = 0 + return + # Wait for something to do if len(self._chunks_index_to_be_downloaded) == 0 and len(self._chunks_index_to_be_deleted) == 0: - sleep(0.01) continue # Delete the chunks if we are missing disk space. @@ -93,7 +110,7 @@ class PrepareChunksThread(Thread): # If we have already downloaded too many chunks, let's wait for processed chunks to catch up if self._max_cache_size and (self._downloaded_chunks - self._processed_chunks) > self._pre_download: - sleep(0.01) + sleep(0.1) continue chunk_index = self._chunks_index_to_be_downloaded.pop(0) @@ -101,6 +118,9 @@ class PrepareChunksThread(Thread): self._config.download_chunk_from_index(chunk_index) self._downloaded_chunks += 1 + # Sleep to release the lock + sleep(0.1) + class BinaryReader: def __init__( @@ -141,7 +161,6 @@ class BinaryReader: self._rank: Optional[int] = None self._config: Optional[ChunksConfig] = None self._prepare_thread: Optional[PrepareChunksThread] = None - self._chunks_index_to_be_downloaded: List[int] = [] self._item_loader = item_loader or PyTreeLoader() self._last_chunk_index: Optional[int] = None self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size)) @@ -193,7 +212,6 @@ class BinaryReader: self._prepare_thread = PrepareChunksThread(self._config, self._max_cache_size) self._prepare_thread.start() if index.chunk_indexes: - self._chunks_index_to_be_downloaded.extend(index.chunk_indexes) self._prepare_thread.download(index.chunk_indexes) # If the chunk_index isn't already in the download and delete queues, add it. 
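The headline item-loader fix above is the byte-offset computation in TokensLoader.load_item_from_chunk: the corrected offset skips (index - begin) whole blocks of tokens rather than single elements, and drops the confusing fallback branch. A minimal standalone sketch of the corrected arithmetic, using the same memmap/frombuffer pattern as the loader; the file name, dtype, and sizes are illustrative and a header-free chunk body is assumed (the real chunks carry a header that the loader skips via its own offset):

import numpy as np
import torch

block_size = 4        # tokens returned per item
token_bytes = 8       # itemsize of the stored dtype (int64 here)
begin = 0             # first item index covered by this chunk, from its interval

# Fake "chunk body": 3 blocks of consecutive token ids, no header.
np.arange(3 * block_size, dtype=np.int64).tofile("chunk_body.bin")
buffer = memoryview(np.memmap("chunk_body.bin", mode="r", order="C"))

def load_block(index: int) -> torch.Tensor:
    # The fix: advance by (index - begin) whole blocks, i.e. multiply by
    # block_size, instead of advancing by (index - begin) single elements.
    offset = token_bytes * (index - begin) * block_size
    return torch.frombuffer(buffer, dtype=torch.int64, count=block_size, offset=offset)

assert load_block(1).tolist() == [4, 5, 6, 7]
assert load_block(2).tolist() == [8, 9, 10, 11]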
@@ -208,7 +226,13 @@ class BinaryReader: # Fetch the element chunk_filepath, begin, _ = self.config[index] - return self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin) + item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin) + + if index.last_index and self._prepare_thread: + self._prepare_thread.stop() + self._prepare_thread = None + + return item def get_length(self) -> int: """Get the number of samples across all chunks.""" diff --git a/src/lightning/data/streaming/sampler.py b/src/lightning/data/streaming/sampler.py index cec70466fc..c4f4d35602 100644 --- a/src/lightning/data/streaming/sampler.py +++ b/src/lightning/data/streaming/sampler.py @@ -25,6 +25,7 @@ class ChunkedIndex: index: int chunk_index: int chunk_indexes: Optional[List[int]] = None + last_index: bool = False class CacheBatchSampler: diff --git a/src/lightning/data/streaming/shuffle.py b/src/lightning/data/streaming/shuffle.py index 4d96b50be4..d389cc3e66 100644 --- a/src/lightning/data/streaming/shuffle.py +++ b/src/lightning/data/streaming/shuffle.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from functools import lru_cache -from typing import Any, List +from typing import Any, List, Tuple import numpy as np @@ -58,15 +58,11 @@ class NoShuffle(Shuffle): @lru_cache(maxsize=10) def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any: - self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore chunk_intervals = self.cache.get_chunk_intervals() - indexes = list(range(len(chunk_intervals))) - shuffled_chunk_intervals = np.asarray(chunk_intervals)[indexes] - chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)] - intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)] - for index, (chunk_index, chunk_interval) in enumerate(zip(indexes, shuffled_chunk_intervals)): - replica_index = index % distributed_env.world_size + intervals_per_ranks: List[List[Tuple]] = [[] for _ in range(distributed_env.world_size)] + for chunk_index, chunk_interval in enumerate(chunk_intervals): + replica_index = chunk_index % distributed_env.world_size chunks_per_ranks[replica_index].append(chunk_index) intervals_per_ranks[replica_index].append(chunk_interval) diff --git a/tests/tests_data/streaming/test_cache.py b/tests/tests_data/streaming/test_cache.py index d95be38e58..111e739394 100644 --- a/tests/tests_data/streaming/test_cache.py +++ b/tests/tests_data/streaming/test_cache.py @@ -22,6 +22,7 @@ from lightning import seed_everything from lightning.data.streaming import Cache from lightning.data.streaming.dataloader import StreamingDataLoader from lightning.data.streaming.dataset import StreamingDataset +from lightning.data.streaming.item_loader import TokensLoader from lightning.data.streaming.serializers import Serializer from lightning.data.utilities.env import _DistributedEnv from lightning.fabric import Fabric @@ -276,3 +277,37 @@ def test_custom_serializer(tmpdir): cache.done() cache.merge() assert isinstance(cache[0][0], bytes) + + +def test_cache_for_text_tokens(tmpdir): + seed_everything(42) + + block_size = 1024 + 1 + cache = Cache(input_dir=str(tmpdir), chunk_size=block_size * 11, item_loader=TokensLoader(block_size)) + text_idxs_list = [] + + counter = 0 + while True: + text_ids = torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int) + text_idxs_list.append(text_ids) + 
chunk_filepath = cache._add_item(counter, text_ids) + if chunk_filepath: + break + counter += 1 + + cache.done() + cache.merge() + + assert len(cache) == 10 + + cache_0 = cache[0] + cache_1 = cache[1] + assert len(cache_0) == block_size + assert len(cache_1) == block_size + assert not torch.equal(cache_0, cache[1]) + indices = torch.cat(text_idxs_list, dim=0) + assert torch.equal(cache_0, indices[: len(cache_0)]) + assert torch.equal(cache_1, indices[len(cache_0) : len(cache_0) + len(cache_1)]) + + with pytest.raises(ValueError, match="TokensLoader"): + len(Cache(str(tmpdir), chunk_size=block_size * 11)) diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index ea4c2cba6b..e5c58661ad 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -15,10 +15,13 @@ import os import sys from unittest import mock +import numpy as np import pytest +import torch from lightning import seed_everything -from lightning.data.streaming import Cache +from lightning.data.streaming import Cache, functions from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir +from lightning.data.streaming.item_loader import TokensLoader from lightning.data.streaming.shuffle import FullShuffle, NoShuffle from lightning.data.utilities.env import _DistributedEnv from torch.utils.data import DataLoader @@ -326,3 +329,208 @@ def test_try_create_cache_dir(): assert _try_create_cache_dir("dir", shard_rank=2) == os.path.join( "/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "2" ) + + +def test_dataset_for_text_tokens(tmpdir): + seed_everything(42) + + block_size = 1024 + 1 + cache = Cache(input_dir=str(tmpdir), chunk_size=block_size * 11, item_loader=TokensLoader(block_size)) + text_idxs_list = [] + + counter = 0 + while True: + text_ids = torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int) + text_idxs_list.append(text_ids) + chunk_filepath = cache._add_item(counter, text_ids) + if chunk_filepath: + break + counter += 1 + + cache.done() + cache.merge() + + dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size)) + + assert len(dataset) == 10 + + cache_0 = dataset[0] + cache_1 = dataset[1] + cache_2 = dataset[2] + cache_3 = dataset[3] + assert len(cache_0) == block_size + assert len(cache_1) == block_size + assert not torch.equal(cache_0, cache[1]) + indices = torch.cat(text_idxs_list, dim=0) + assert torch.equal(cache_0, indices[: len(cache_0)]) + assert torch.equal(cache_1, indices[len(cache_0) : len(cache_0) + len(cache_1)]) + + dataloader = DataLoader(StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size)), batch_size=2) + + for batch_idx, batch in enumerate(dataloader): + if batch_idx == 0: + assert torch.equal(torch.stack([cache_0, cache_1]), batch) + elif batch_idx == 1: + assert torch.equal(torch.stack([cache_2, cache_3]), batch) + else: + break + + +def test_dataset_for_text_tokens_multiple_workers(tmpdir): + seed_everything(42) + + block_size = 10 + cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size)) + + counter = 0 + for i in range(10): + text_ids = torch.arange(counter, counter + 20).to(torch.int) + cache[i] = text_ids + counter += 20 + + cache.done() + cache.merge() + + for i in range(20): + sequence = cache[i] + assert sequence[0].item() == i * block_size + assert sequence[-1].item() == (i + 1) * block_size - 1 + + assert len(os.listdir(tmpdir)) == 6 + + dataset = 
StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + + assert len(dataset) == 20 + + dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False) + + assert len(dataloader) == 10 + + expected = [ + [0, 10], + [40, 50], + [20, 30], + [60, 70], + [80, 90], + [120, 130], + [100, 110], + [140, 150], + [160, 170], + [180, 190], + ] + + for result, batch in zip(expected, dataloader): + assert [batch[0][0].item(), batch[1][0].item()] == result + + +def test_dataset_for_text_tokens_distributed_num_workers(tmpdir): + seed_everything(42) + + block_size = 10 + cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size)) + + counter = 0 + for i in range(10): + text_ids = torch.arange(counter, counter + 20).to(torch.int) + cache[i] = text_ids + counter += 20 + + cache.done() + cache.merge() + + for i in range(20): + sequence = cache[i] + assert sequence[0].item() == i * block_size + assert sequence[-1].item() == (i + 1) * block_size - 1 + + assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 5 + + dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + + assert len(dataset) == 20 + + dataset.distributed_env = _DistributedEnv(2, 0) + dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) + + assert len(dataloader) == 6 + + expected = [[0, 10], [80, 90], [20, 30], [100, 110], [160, 170], [180, 190]] + + for batch_idx, batch in enumerate(dataloader): + assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + + dataset.distributed_env = _DistributedEnv(2, 1) + dataloader = DataLoader(dataset, batch_size=2, shuffle=False) + + assert len(dataloader) == 4 + + expected = [[40, 50], [60, 70], [120, 130], [140, 150]] + + for batch_idx, batch in enumerate(dataloader): + assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + + for batch_idx, batch in enumerate(dataloader): + assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + + +def optimize_fn(item): + return torch.arange(item[0], item[0] + 20).to(torch.int) + + +def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monkeypatch): + monkeypatch.setattr(functions, "_get_input_dir", lambda x: str(tmpdir)) + + seed_everything(42) + + with open(tmpdir / "a.txt", "w") as f: + f.write("hello") + + inputs = [(v, str(tmpdir / "a.txt")) for v in range(0, 200, 20)] + + cache_dir = os.path.join(tmpdir, "cache") + output_dir = os.path.join(tmpdir, "target_dir") + os.makedirs(output_dir, exist_ok=True) + monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir) + monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir) + + def fn(item): + return torch.arange(item[0], item[0] + 20).to(torch.int) + + functions.optimize( + optimize_fn, inputs, output_dir=str(tmpdir), num_workers=2, chunk_size=2, reorder_files=False, num_downloaders=1 + ) + + assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 10 + + block_size = 10 + dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + + L = len(dataset) + assert len(dataset) == L + + for i in range(L): + sequence = dataset[i] + assert sequence[0].item() == i * block_size + assert sequence[-1].item() == (i + 1) * block_size - 1 + + dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + + dataset.distributed_env = _DistributedEnv(2, 0) + dataloader = DataLoader(dataset, 
batch_size=2, shuffle=False, num_workers=2) + + assert len(dataloader) == 5 + + expected = [[0, 10], [40, 50], [80, 90], [120, 130], [160, 170]] + + for batch_idx, batch in enumerate(dataloader): + assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + + dataset.distributed_env = _DistributedEnv(2, 1) + dataloader = DataLoader(dataset, batch_size=2, shuffle=False) + + assert len(dataloader) == 5 + + expected = [[20, 30], [60, 70], [100, 110], [140, 150], [180, 190]] + + for batch_idx, batch in enumerate(dataloader): + assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] diff --git a/tests/tests_data/streaming/test_reader.py b/tests/tests_data/streaming/test_reader.py index 93ddebf873..226945a576 100644 --- a/tests/tests_data/streaming/test_reader.py +++ b/tests/tests_data/streaming/test_reader.py @@ -4,6 +4,7 @@ from unittest import mock from lightning.data.streaming import reader from lightning.data.streaming.cache import Cache +from lightning.data.streaming.config import ChunkedIndex from lightning_cloud.resolver import Dir @@ -30,7 +31,8 @@ def test_reader_chunk_removal(tmpdir, monkeypatch): os.makedirs(cache_dir, exist_ok=True) for i in range(25): - assert cache[i] == i + index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), last_index=i == 24) + assert cache[index] == i assert len(os.listdir(cache_dir)) == 14 @@ -46,7 +48,8 @@ def test_reader_chunk_removal(tmpdir, monkeypatch): expected = [] for i in range(25): expected.append([i, len(os.listdir(cache_dir))]) - assert cache[i] == i + index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), last_index=i == 24) + assert cache[index] == i assert expected == [ [0, 0], @@ -70,10 +73,10 @@ def test_reader_chunk_removal(tmpdir, monkeypatch): [18, 9], [19, 10], [20, 10], - [21, 11], - [22, 11], # Cleanup is triggered - [23, 2], - [24, 2], + [21, 2], + [22, 2], + [23, 3], + [24, 3], ] - assert len(os.listdir(cache_dir)) == 3 + assert len(os.listdir(cache_dir)) in [3, 4]
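Taken together, the new tests exercise the token-streaming path end to end: chunks of tokens are written with Cache + TokensLoader, then read back as fixed-size blocks through StreamingDataset and a DataLoader. A condensed, local-only usage sketch distilled from those tests; the directory path, chunk size, and token counts are illustrative:

import os
import torch
from torch.utils.data import DataLoader
from lightning.data.streaming import Cache
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader

data_dir = "/tmp/token_chunks"  # hypothetical location for the chunk files
block_size = 10
os.makedirs(data_dir, exist_ok=True)

# Write 10 samples of 20 tokens each, packed into chunks of 40 tokens.
cache = Cache(input_dir=data_dir, chunk_size=40, item_loader=TokensLoader(block_size))
for i in range(10):
    cache[i] = torch.arange(i * 20, (i + 1) * 20).to(torch.int)
cache.done()
cache.merge()

# Each __getitem__ now returns one contiguous block of `block_size` tokens,
# so 200 tokens yield 20 items.
dataset = StreamingDataset(input_dir=data_dir, item_loader=TokensLoader(block_size), shuffle=False)
loader = DataLoader(dataset, batch_size=2)
for batch in loader:
    assert batch.shape == (2, block_size)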