From 85933f355a14eb1a38759acd4f249178e4949486 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 1 Nov 2023 09:35:35 +0000 Subject: [PATCH] Improve map and chunkify (#18901) Co-authored-by: thomas Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- requirements/app/app.txt | 2 +- src/lightning/data/__init__.py | 3 +- src/lightning/data/streaming/cache.py | 4 +- src/lightning/data/streaming/constants.py | 2 +- .../data/streaming/data_processor.py | 55 +++-- src/lightning/data/streaming/functions.py | 207 ++++++++++++++++++ src/lightning/data/streaming/item_loader.py | 3 + src/lightning/data/streaming/writer.py | 15 +- .../streaming/test_data_processor.py | 73 +++++- 9 files changed, 332 insertions(+), 32 deletions(-) create mode 100644 src/lightning/data/streaming/functions.py diff --git a/requirements/app/app.txt b/requirements/app/app.txt index a6ac03f4be..5b14c88b64 100644 --- a/requirements/app/app.txt +++ b/requirements/app/app.txt @@ -1,4 +1,4 @@ -lightning-cloud ==0.5.44 # Must be pinned to ensure compatibility +lightning-cloud ==0.5.46 # Must be pinned to ensure compatibility packaging typing-extensions >=4.0.0, <4.8.0 deepdiff >=5.7.0, <6.6.0 diff --git a/src/lightning/data/__init__.py b/src/lightning/data/__init__.py index 3a349ec7c9..cd86acb08a 100644 --- a/src/lightning/data/__init__.py +++ b/src/lightning/data/__init__.py @@ -1,10 +1,11 @@ from lightning.data.datasets import LightningDataset, LightningIterableDataset from lightning.data.streaming.dataset import StreamingDataset -from lightning.data.streaming.map import map +from lightning.data.streaming.functions import map, optimize __all__ = [ "LightningDataset", "StreamingDataset", "LightningIterableDataset", "map", + "optimize", ] diff --git a/src/lightning/data/streaming/cache.py b/src/lightning/data/streaming/cache.py index e2df4bded0..520ec649bb 100644 --- a/src/lightning/data/streaming/cache.py +++ b/src/lightning/data/streaming/cache.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union from lightning.data.datasets.env import _DistributedEnv from lightning.data.streaming.constants import ( _INDEX_FILENAME, - _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42, + _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46, _TORCH_GREATER_EQUAL_2_1_0, ) from lightning.data.streaming.item_loader import BaseItemLoader @@ -26,7 +26,7 @@ from lightning.data.streaming.reader import BinaryReader from lightning.data.streaming.sampler import ChunkedIndex from lightning.data.streaming.writer import BinaryWriter -if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42: +if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46: from lightning_cloud.resolver import _find_remote_dir, _try_create_cache_dir logger = logging.Logger(__name__) diff --git a/src/lightning/data/streaming/constants.py b/src/lightning/data/streaming/constants.py index feb0dec536..722baf111f 100644 --- a/src/lightning/data/streaming/constants.py +++ b/src/lightning/data/streaming/constants.py @@ -21,7 +21,7 @@ _DEFAULT_FAST_DEV_RUN_ITEMS = 10 # This is required for full pytree serialization / deserialization support _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0") _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer") -_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42 = RequirementCache("lightning-cloud>=0.5.42") +_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46 = RequirementCache("lightning-cloud>=0.5.46") _BOTO3_AVAILABLE = RequirementCache("boto3") # DON'T CHANGE ORDER diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index de6db32707..5e9f19b3d2 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -5,11 +5,12 @@ import tempfile import traceback import types from abc import abstractmethod +from dataclasses import dataclass from multiprocessing import Process, Queue from queue import Empty from shutil import copyfile, rmtree from time import sleep, time -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union from urllib import parse import torch @@ -21,7 +22,7 @@ from lightning.data.streaming.constants import ( _BOTO3_AVAILABLE, _DEFAULT_FAST_DEV_RUN_ITEMS, _INDEX_FILENAME, - _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42, + _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46, _TORCH_GREATER_EQUAL_2_1_0, ) from lightning.fabric.accelerators.cuda import is_cuda_available @@ -35,7 +36,7 @@ from lightning.fabric.utilities.distributed import group as _group if _TORCH_GREATER_EQUAL_2_1_0: from torch.utils._pytree import tree_flatten, tree_unflatten -if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42: +if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46: from lightning_cloud.resolver import _LightningSrcResolver, _LightningTargetResolver if _BOTO3_AVAILABLE: @@ -160,10 +161,14 @@ def _remove_target(input_dir: str, cache_dir: str, queue_in: Queue) -> None: # 3. Iterate through the paths and delete them sequentially. for path in paths: if input_dir: - cached_filepath = path.replace(input_dir, cache_dir) + if not path.startswith(cache_dir): + path = path.replace(input_dir, cache_dir) - if os.path.exists(cached_filepath): - os.remove(cached_filepath) + if os.path.exists(path): + os.remove(path) + + elif os.path.exists(path) and "s3_connections" not in path: + os.remove(path) def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, remote_output_dir: str) -> None: @@ -387,7 +392,9 @@ class BaseWorker: } if len(indexed_paths) == 0: - raise ValueError(f"The provided item {item} didn't contain any filepaths. {flattened_item}") + raise ValueError( + f"The provided item {item} didn't contain any filepaths. The input_dir is {self.input_dir}." + ) paths = [] for index, path in indexed_paths.items(): @@ -548,7 +555,7 @@ class DataRecipe: def _setup(self, name: Optional[str]) -> None: self._name = name - def _done(self, delete_cached_files: bool, remote_output_dir: str) -> None: + def _done(self, delete_cached_files: bool, remote_output_dir: Any) -> None: pass @@ -578,7 +585,6 @@ class DataChunkRecipe(DataRecipe): def _done(self, delete_cached_files: bool, remote_output_dir: str) -> None: num_nodes = _get_num_nodes() - assert self._name cache_dir = _get_cache_dir(self._name) chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")] @@ -647,6 +653,14 @@ class DataTransformRecipe(DataRecipe): """Use your item metadata to process your files and save the file outputs into `output_dir`.""" +@dataclass +class PrettyDirectory: + """Holds a directory and its URL.""" + + directory: str + url: str + + class DataProcessor: def __init__( self, @@ -656,10 +670,11 @@ class DataProcessor: num_downloaders: Optional[int] = None, delete_cached_files: bool = True, src_resolver: Optional[Callable[[str], Optional[str]]] = None, - fast_dev_run: Optional[bool] = None, + fast_dev_run: Optional[Union[bool, int]] = None, remote_input_dir: Optional[str] = None, - remote_output_dir: Optional[str] = None, + remote_output_dir: Optional[Union[str, PrettyDirectory]] = None, random_seed: Optional[int] = 42, + version: Optional[int] = None, ): """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make training faster. @@ -692,18 +707,22 @@ class DataProcessor: self.remote_input_dir = ( str(remote_input_dir) if remote_input_dir is not None - else ((self.src_resolver(input_dir) if input_dir else None) if self.src_resolver else None) + else ((self.src_resolver(str(input_dir)) if input_dir else None) if self.src_resolver else None) ) self.remote_output_dir = ( remote_output_dir if remote_output_dir is not None - else (self.dst_resolver(name) if self.dst_resolver else None) + else (self.dst_resolver(name, version=version) if self.dst_resolver else None) ) if self.remote_output_dir: self.name = self._broadcast_object(self.name) # Ensure the remote src dir is the same across all ranks self.remote_output_dir = self._broadcast_object(self.remote_output_dir) - print(f"Storing the files under {self.remote_output_dir}") + if isinstance(self.remote_output_dir, PrettyDirectory): + print(f"Storing the files under {self.remote_output_dir.directory}") + self.remote_output_dir = self.remote_output_dir.url + else: + print(f"Storing the files under {self.remote_output_dir}") self.random_seed = random_seed @@ -725,7 +744,7 @@ class DataProcessor: user_items: List[Any] = data_recipe.prepare_structure(self.input_dir) if not isinstance(user_items, list): - raise ValueError("The setup_fn should return a list of item metadata.") + raise ValueError("The `prepare_structure` should return a list of item metadata.") # Associate the items to the workers based on num_nodes and node_rank begins, workers_user_items = _associated_items_to_workers(self.num_workers, user_items) @@ -779,6 +798,8 @@ class DataProcessor: w.join(0) print("Workers are finished.") + if self.remote_output_dir: + assert isinstance(self.remote_output_dir, str) data_recipe._done(self.delete_cached_files, self.remote_output_dir) print("Finished data processing!") @@ -856,7 +877,7 @@ class DataProcessor: # Cleanup the cache dir folder to avoid corrupted files from previous run to be there. if os.path.exists(cache_dir): - rmtree(cache_dir) + rmtree(cache_dir, ignore_errors=True) os.makedirs(cache_dir, exist_ok=True) @@ -864,7 +885,7 @@ class DataProcessor: # Cleanup the cache data folder to avoid corrupted files from previous run to be there. if os.path.exists(cache_data_dir): - rmtree(cache_data_dir) + rmtree(cache_data_dir, ignore_errors=True) os.makedirs(cache_data_dir, exist_ok=True) diff --git a/src/lightning/data/streaming/functions.py b/src/lightning/data/streaming/functions.py new file mode 100644 index 0000000000..2bb60dfb34 --- /dev/null +++ b/src/lightning/data/streaming/functions.py @@ -0,0 +1,207 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from datetime import datetime +from pathlib import Path +from types import GeneratorType +from typing import Any, Callable, Optional, Sequence, Union + +from lightning.data.streaming.constants import _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46, _TORCH_GREATER_EQUAL_2_1_0 +from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe, PrettyDirectory + +if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46: + from lightning_cloud.resolver import _execute, _LightningSrcResolver + +if _TORCH_GREATER_EQUAL_2_1_0: + from torch.utils._pytree import tree_flatten + + +def _get_input_dir(inputs: Sequence[Any]) -> str: + flattened_item, _ = tree_flatten(inputs[0]) + + indexed_paths = { + index: element + for index, element in enumerate(flattened_item) + if isinstance(element, str) and os.path.exists(element) + } + + if len(indexed_paths) == 0: + raise ValueError(f"The provided item {inputs[0]} didn't contain any filepaths.") + + absolute_path = str(Path(indexed_paths[0]).resolve()) + + if indexed_paths[0] != absolute_path: + raise ValueError("The provided path should be absolute.") + + return "/" + os.path.join(*str(absolute_path).split("/")[:4]) + + +class LambdaDataTransformRecipe(DataTransformRecipe): + def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]): + super().__init__() + self._fn = fn + self._inputs = inputs + + def prepare_structure(self, input_dir: Optional[str]) -> Any: + return self._inputs + + def prepare_item(self, output_dir: str, item_metadata: Any) -> None: # type: ignore + self._fn(output_dir, item_metadata) + + +class LambdaDataChunkRecipe(DataChunkRecipe): + def __init__( + self, + fn: Callable[[Any], None], + inputs: Sequence[Any], + chunk_size: Optional[int], + chunk_bytes: Optional[int], + compression: Optional[str], + ): + super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression) + self._fn = fn + self._inputs = inputs + + def prepare_structure(self, input_dir: Optional[str]) -> Any: + return self._inputs + + def prepare_item(self, item_metadata: Any) -> Any: # type: ignore + if isinstance(self._fn, GeneratorType): + yield from self._fn(item_metadata) + else: + yield self._fn(item_metadata) + + +def map( + fn: Callable[[str, Any], None], + inputs: Sequence[Any], + output_dir: str, + num_workers: Optional[int] = None, + fast_dev_run: Union[bool, int] = False, + num_nodes: Optional[int] = None, + machine: Optional[str] = None, + input_dir: Optional[str] = None, +) -> None: + """This function map a callbable over a collection of files possibly in a distributed way. + + Arguments: + fn: A function to be executed over each input element + inputs: A sequence of input to be processed by the `fn` function. + Each input should contain at least a valid filepath. + output_dir: The folder where the processed data should be stored. + num_workers: The number of workers to use during processing + fast_dev_run: Whether to use process only a sub part of the inputs + num_nodes: When doing remote execution, the number of nodes to use. + machine: When doing remote execution, the machine to use. + + """ + if not isinstance(inputs, Sequence): + raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.") + + if len(inputs) == 0: + raise ValueError(f"The provided inputs should be non empty. Found {inputs}.") + + if num_nodes is None or int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 0)) > 0: + remote_output_dir = _LightningSrcResolver()(output_dir) + + if remote_output_dir is None or "cloudspaces" in remote_output_dir: + raise ValueError( + f"The provided `output_dir` isn't valid. Found {output_dir}." + " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." + ) + + data_processor = DataProcessor( + num_workers=num_workers or os.cpu_count(), + remote_output_dir=PrettyDirectory(output_dir, remote_output_dir), + fast_dev_run=fast_dev_run, + version=None, + input_dir=input_dir or _get_input_dir(inputs), + ) + return data_processor.run(LambdaDataTransformRecipe(fn, inputs)) + return _execute( + f"data-prep-map-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}", + num_nodes, + machine, + ) + + +def optimize( + fn: Callable[[Any], Any], + inputs: Sequence[Any], + output_dir: str, + chunk_size: Optional[int] = None, + chunk_bytes: Optional[int] = None, + compression: Optional[str] = None, + name: Optional[str] = None, + num_workers: Optional[int] = None, + fast_dev_run: bool = False, + num_nodes: Optional[int] = None, + machine: Optional[str] = None, + input_dir: Optional[str] = None, +) -> None: + """This function converts a dataset into chunks possibly in a distributed way. + + Arguments: + fn: A function to be executed over each input element + inputs: A sequence of input to be processed by the `fn` function. + Each input should contain at least a valid filepath. + output_dir: The folder where the processed data should be stored. + chunk_size: The maximum number of elements to hold within a chunk. + chunk_bytes: The maximum number of bytes to hold within a chunk. + compression: The compression algorithm to use over the chunks. + num_workers: The number of workers to use during processing + fast_dev_run: Whether to use process only a sub part of the inputs + num_nodes: When doing remote execution, the number of nodes to use. + machine: When doing remote execution, the machine to use. + + """ + if not isinstance(inputs, Sequence): + raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.") + + if len(inputs) == 0: + raise ValueError(f"The provided inputs should be non empty. Found {inputs}.") + + if chunk_size is None and chunk_bytes is None: + raise ValueError("Either `chunk_size` or `chunk_bytes` needs to be defined.") + + if num_nodes is None or int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 0)) > 0: + remote_output_dir = _LightningSrcResolver()(output_dir) + + if remote_output_dir is None or "cloudspaces" in remote_output_dir: + raise ValueError( + f"The provided `output_dir` isn't valid. Found {output_dir}." + " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." + ) + + data_processor = DataProcessor( + name=name, + num_workers=num_workers or os.cpu_count(), + remote_output_dir=PrettyDirectory(output_dir, remote_output_dir), + fast_dev_run=fast_dev_run, + input_dir=input_dir or _get_input_dir(inputs), + ) + return data_processor.run( + LambdaDataChunkRecipe( + fn, + inputs, + chunk_size=chunk_size, + chunk_bytes=chunk_bytes, + compression=compression, + ) + ) + return _execute( + f"data-prep-optimize-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}", + num_nodes, + machine, + ) diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py index 8b31f5b4ef..60a9d08fc1 100644 --- a/src/lightning/data/streaming/item_loader.py +++ b/src/lightning/data/streaming/item_loader.py @@ -70,6 +70,9 @@ class PyTreeLoader(BaseItemLoader): if chunk_filepath not in self._chunk_filepaths: while not os.path.exists(chunk_filepath): sleep(0.001) + + # Wait to avoid any corruption when the file appears + sleep(0.001) self._chunk_filepaths[chunk_filepath] = True with open(chunk_filepath, "rb", 0) as fp: diff --git a/src/lightning/data/streaming/writer.py b/src/lightning/data/streaming/writer.py index 631b1167fe..1ec2995eae 100644 --- a/src/lightning/data/streaming/writer.py +++ b/src/lightning/data/streaming/writer.py @@ -206,7 +206,8 @@ class BinaryWriter: if len(items) == 0: raise RuntimeError( - f"The items shouldn't have an empty length. Something went wrong. Found {self._serialized_items}." + "The items shouldn't have an empty length. Something went wrong." + f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}." ) sizes = list(map(len, items)) @@ -414,3 +415,15 @@ class BinaryWriter: return f1 != f2 return any(is_non_valid(f1, f2) for f1, f2 in zip(data_format_1, data_format_2)) + + def _pretty_serialized_items(self) -> Dict[int, Item]: + out = {} + for key, value in self._serialized_items.items(): + # drop `data` as it would make logs unreadable. + out[key] = Item( + index=value.index, + bytes=value.bytes, + dim=value.dim, + data=b"", + ) + return out diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py index 2a767c39af..90ede64627 100644 --- a/tests/tests_data/streaming/test_data_processor.py +++ b/tests/tests_data/streaming/test_data_processor.py @@ -8,6 +8,8 @@ import pytest import torch from lightning import seed_everything from lightning.data.streaming import data_processor as data_processor_module +from lightning.data.streaming import functions +from lightning.data.streaming.cache import Cache from lightning.data.streaming.data_processor import ( DataChunkRecipe, DataProcessor, @@ -18,7 +20,7 @@ from lightning.data.streaming.data_processor import ( _upload_fn, _wait_for_file_to_exist, ) -from lightning.data.streaming.map import map +from lightning.data.streaming.functions import map, optimize from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") @@ -519,7 +521,7 @@ def test_data_process_transform(monkeypatch, tmpdir): assert img.size == (12, 12) -def fn(output_dir, filepath): +def map_fn(output_dir, filepath): from PIL import Image img = Image.open(filepath) @@ -528,31 +530,84 @@ def fn(output_dir, filepath): img.save(os.path.join(output_dir, os.path.basename(filepath))) +@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']") def test_data_processing_map(monkeypatch, tmpdir): from PIL import Image + input_dir = os.path.join(tmpdir, "input_dir") + os.makedirs(input_dir, exist_ok=True) imgs = [] for i in range(5): np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32) img = Image.fromarray(np_data).convert("L") imgs.append(img) - img.save(os.path.join(tmpdir, f"{i}.JPEG")) + img.save(os.path.join(input_dir, f"{i}.JPEG")) home_dir = os.path.join(tmpdir, "home") cache_dir = os.path.join(tmpdir, "cache") - remote_output_dir = os.path.join(tmpdir, "target_dir") - os.makedirs(remote_output_dir, exist_ok=True) + output_dir = os.path.join(tmpdir, "target_dir") + os.makedirs(output_dir, exist_ok=True) monkeypatch.setenv("DATA_OPTIMIZER_HOME_FOLDER", home_dir) monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir) - inputs = [os.path.join(tmpdir, filename) for filename in os.listdir(tmpdir)] + resolver = mock.MagicMock() + resolver.return_value = lambda x: x + monkeypatch.setattr(functions, "_LightningSrcResolver", resolver) + monkeypatch.setattr(data_processor_module, "_LightningSrcResolver", resolver) + monkeypatch.setattr(data_processor_module, "_LightningTargetResolver", resolver) + + inputs = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)] inputs = [filepath for filepath in inputs if os.path.isfile(filepath)] - map(fn, inputs, num_workers=1, remote_output_dir=remote_output_dir) + map(map_fn, inputs, num_workers=1, output_dir=output_dir, input_dir=input_dir) - assert sorted(os.listdir(remote_output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"] + assert sorted(os.listdir(output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"] from PIL import Image - img = Image.open(os.path.join(remote_output_dir, "0.JPEG")) + img = Image.open(os.path.join(output_dir, "0.JPEG")) assert img.size == (12, 12) + + +def optimize_fn(filepath): + print(filepath) + from PIL import Image + + return [Image.open(filepath), os.path.basename(filepath)] + + +@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']") +def test_data_processing_optimize(monkeypatch, tmpdir): + from PIL import Image + + input_dir = os.path.join(tmpdir, "input_dir") + os.makedirs(input_dir, exist_ok=True) + imgs = [] + for i in range(5): + np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32) + img = Image.fromarray(np_data).convert("L") + imgs.append(img) + img.save(os.path.join(input_dir, f"{i}.JPEG")) + + home_dir = os.path.join(tmpdir, "home") + cache_dir = os.path.join(tmpdir, "cache") + output_dir = os.path.join(tmpdir, "target_dir") + os.makedirs(output_dir, exist_ok=True) + monkeypatch.setenv("DATA_OPTIMIZER_HOME_FOLDER", home_dir) + monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir) + + inputs = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)] + inputs = [filepath for filepath in inputs if os.path.isfile(filepath)] + + resolver = mock.MagicMock() + resolver.return_value = lambda x: x + monkeypatch.setattr(functions, "_LightningSrcResolver", resolver) + monkeypatch.setattr(data_processor_module, "_LightningSrcResolver", resolver) + monkeypatch.setattr(data_processor_module, "_LightningTargetResolver", resolver) + + optimize(optimize_fn, inputs, num_workers=1, output_dir=output_dir, chunk_size=2, input_dir=input_dir) + + assert sorted(os.listdir(output_dir)) == ["chunk-0-0.bin", "chunk-0-1.bin", "chunk-0-2.bin", "index.json"] + + cache = Cache(output_dir, chunk_size=1) + assert len(cache) == 5