Improve map and chunkify (#18901)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent 31b8777350
commit 85933f355a

@@ -1,4 +1,4 @@
-lightning-cloud ==0.5.44 # Must be pinned to ensure compatibility
+lightning-cloud ==0.5.46 # Must be pinned to ensure compatibility
 packaging
 typing-extensions >=4.0.0, <4.8.0
 deepdiff >=5.7.0, <6.6.0

@@ -1,10 +1,11 @@
 from lightning.data.datasets import LightningDataset, LightningIterableDataset
 from lightning.data.streaming.dataset import StreamingDataset
-from lightning.data.streaming.map import map
+from lightning.data.streaming.functions import map, optimize

 __all__ = [
     "LightningDataset",
     "StreamingDataset",
     "LightningIterableDataset",
     "map",
+    "optimize",
 ]

@@ -18,7 +18,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 from lightning.data.datasets.env import _DistributedEnv
 from lightning.data.streaming.constants import (
     _INDEX_FILENAME,
-    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42,
+    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning.data.streaming.item_loader import BaseItemLoader

@@ -26,7 +26,7 @@ from lightning.data.streaming.reader import BinaryReader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.writer import BinaryWriter

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
     from lightning_cloud.resolver import _find_remote_dir, _try_create_cache_dir

 logger = logging.Logger(__name__)

@@ -21,7 +21,7 @@ _DEFAULT_FAST_DEV_RUN_ITEMS = 10
 # This is required for full pytree serialization / deserialization support
 _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
 _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
-_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42 = RequirementCache("lightning-cloud>=0.5.42")
+_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46 = RequirementCache("lightning-cloud>=0.5.46")
 _BOTO3_AVAILABLE = RequirementCache("boto3")

 # DON'T CHANGE ORDER

@@ -5,11 +5,12 @@ import tempfile
 import traceback
 import types
 from abc import abstractmethod
+from dataclasses import dataclass
 from multiprocessing import Process, Queue
 from queue import Empty
 from shutil import copyfile, rmtree
 from time import sleep, time
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
 from urllib import parse

 import torch

@@ -21,7 +22,7 @@ from lightning.data.streaming.constants import (
     _BOTO3_AVAILABLE,
     _DEFAULT_FAST_DEV_RUN_ITEMS,
     _INDEX_FILENAME,
-    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42,
+    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning.fabric.accelerators.cuda import is_cuda_available

@@ -35,7 +36,7 @@ from lightning.fabric.utilities.distributed import group as _group
 if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import tree_flatten, tree_unflatten

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
     from lightning_cloud.resolver import _LightningSrcResolver, _LightningTargetResolver

 if _BOTO3_AVAILABLE:

@@ -160,10 +161,14 @@ def _remove_target(input_dir: str, cache_dir: str, queue_in: Queue) -> None:
     # 3. Iterate through the paths and delete them sequentially.
     for path in paths:
         if input_dir:
-            cached_filepath = path.replace(input_dir, cache_dir)
+            if not path.startswith(cache_dir):
+                path = path.replace(input_dir, cache_dir)

-            if os.path.exists(cached_filepath):
-                os.remove(cached_filepath)
+            if os.path.exists(path):
+                os.remove(path)

         elif os.path.exists(path) and "s3_connections" not in path:
             os.remove(path)


 def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, remote_output_dir: str) -> None:
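
For readers skimming the hunk above: the worker now normalizes every path into the cache directory before deleting it, instead of tracking a separate `cached_filepath`. Below is a minimal standalone sketch of that normalization, not part of the commit; the directory names are hypothetical and `input_dir`/`cache_dir` are assumed to be plain string prefixes.

import os


def _to_cache_path(path: str, input_dir: str, cache_dir: str) -> str:
    # Paths already inside the cache are left untouched; anything still under the
    # original input_dir is remapped to its cached copy before deletion.
    if not path.startswith(cache_dir):
        path = path.replace(input_dir, cache_dir)
    return path


# Hypothetical example values:
assert _to_cache_path("/data/in/a.jpg", "/data/in", "/cache/run") == "/cache/run/a.jpg"
assert _to_cache_path("/cache/run/a.jpg", "/data/in", "/cache/run") == "/cache/run/a.jpg"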

@@ -387,7 +392,9 @@ class BaseWorker:
         }

         if len(indexed_paths) == 0:
-            raise ValueError(f"The provided item {item} didn't contain any filepaths. {flattened_item}")
+            raise ValueError(
+                f"The provided item {item} didn't contain any filepaths. The input_dir is {self.input_dir}."
+            )

         paths = []
         for index, path in indexed_paths.items():

@@ -548,7 +555,7 @@ class DataRecipe:
     def _setup(self, name: Optional[str]) -> None:
         self._name = name

-    def _done(self, delete_cached_files: bool, remote_output_dir: str) -> None:
+    def _done(self, delete_cached_files: bool, remote_output_dir: Any) -> None:
         pass


@@ -578,7 +585,6 @@ class DataChunkRecipe(DataRecipe):

     def _done(self, delete_cached_files: bool, remote_output_dir: str) -> None:
         num_nodes = _get_num_nodes()
-        assert self._name
         cache_dir = _get_cache_dir(self._name)

         chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")]

@@ -647,6 +653,14 @@ class DataTransformRecipe(DataRecipe):
         """Use your item metadata to process your files and save the file outputs into `output_dir`."""


+@dataclass
+class PrettyDirectory:
+    """Holds a directory and its URL."""
+
+    directory: str
+    url: str
+
+
 class DataProcessor:
     def __init__(
         self,
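
A minimal sketch, not part of the commit, of how a `PrettyDirectory` value is meant to flow based on the `DataProcessor` hunks further down: the human-readable `directory` is what gets printed, while the resolved `url` is what the processor keeps and uploads to. The example values below are placeholders.

from dataclasses import dataclass


@dataclass
class PrettyDirectory:
    """Holds a directory and its URL."""

    directory: str
    url: str


# Hypothetical values, for illustration only.
remote_output_dir = PrettyDirectory(
    directory="/teamspace/datasets/my_dataset",  # what the user sees in the logs
    url="s3://example-bucket/my_dataset",        # what the processor actually writes to
)

if isinstance(remote_output_dir, PrettyDirectory):
    print(f"Storing the files under {remote_output_dir.directory}")
    remote_output_dir = remote_output_dir.url  # downstream code only needs the resolved URL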

@@ -656,10 +670,11 @@ class DataProcessor:
         num_downloaders: Optional[int] = None,
         delete_cached_files: bool = True,
         src_resolver: Optional[Callable[[str], Optional[str]]] = None,
-        fast_dev_run: Optional[bool] = None,
+        fast_dev_run: Optional[Union[bool, int]] = None,
         remote_input_dir: Optional[str] = None,
-        remote_output_dir: Optional[str] = None,
+        remote_output_dir: Optional[Union[str, PrettyDirectory]] = None,
         random_seed: Optional[int] = 42,
+        version: Optional[int] = None,
     ):
         """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
         training faster.

@@ -692,18 +707,22 @@ class DataProcessor:
         self.remote_input_dir = (
             str(remote_input_dir)
             if remote_input_dir is not None
-            else ((self.src_resolver(input_dir) if input_dir else None) if self.src_resolver else None)
+            else ((self.src_resolver(str(input_dir)) if input_dir else None) if self.src_resolver else None)
         )
         self.remote_output_dir = (
             remote_output_dir
             if remote_output_dir is not None
-            else (self.dst_resolver(name) if self.dst_resolver else None)
+            else (self.dst_resolver(name, version=version) if self.dst_resolver else None)
         )
         if self.remote_output_dir:
             self.name = self._broadcast_object(self.name)
             # Ensure the remote src dir is the same across all ranks
             self.remote_output_dir = self._broadcast_object(self.remote_output_dir)
-            print(f"Storing the files under {self.remote_output_dir}")
+            if isinstance(self.remote_output_dir, PrettyDirectory):
+                print(f"Storing the files under {self.remote_output_dir.directory}")
+                self.remote_output_dir = self.remote_output_dir.url
+            else:
+                print(f"Storing the files under {self.remote_output_dir}")

         self.random_seed = random_seed

@@ -725,7 +744,7 @@ class DataProcessor:
         user_items: List[Any] = data_recipe.prepare_structure(self.input_dir)

         if not isinstance(user_items, list):
-            raise ValueError("The setup_fn should return a list of item metadata.")
+            raise ValueError("The `prepare_structure` should return a list of item metadata.")

         # Associate the items to the workers based on num_nodes and node_rank
         begins, workers_user_items = _associated_items_to_workers(self.num_workers, user_items)

@@ -779,6 +798,8 @@ class DataProcessor:
             w.join(0)

         print("Workers are finished.")
+        if self.remote_output_dir:
+            assert isinstance(self.remote_output_dir, str)
         data_recipe._done(self.delete_cached_files, self.remote_output_dir)
         print("Finished data processing!")

@@ -856,7 +877,7 @@ class DataProcessor:

         # Cleanup the cache dir folder to avoid corrupted files from previous run to be there.
         if os.path.exists(cache_dir):
-            rmtree(cache_dir)
+            rmtree(cache_dir, ignore_errors=True)

         os.makedirs(cache_dir, exist_ok=True)

@@ -864,7 +885,7 @@ class DataProcessor:

         # Cleanup the cache data folder to avoid corrupted files from previous run to be there.
         if os.path.exists(cache_data_dir):
-            rmtree(cache_data_dir)
+            rmtree(cache_data_dir, ignore_errors=True)

         os.makedirs(cache_data_dir, exist_ok=True)

@@ -0,0 +1,207 @@
+# Copyright The Lightning AI team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from datetime import datetime
+from pathlib import Path
+from types import GeneratorType
+from typing import Any, Callable, Optional, Sequence, Union
+
+from lightning.data.streaming.constants import _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46, _TORCH_GREATER_EQUAL_2_1_0
+from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe, PrettyDirectory
+
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
+    from lightning_cloud.resolver import _execute, _LightningSrcResolver
+
+if _TORCH_GREATER_EQUAL_2_1_0:
+    from torch.utils._pytree import tree_flatten
+
+
+def _get_input_dir(inputs: Sequence[Any]) -> str:
+    flattened_item, _ = tree_flatten(inputs[0])
+
+    indexed_paths = {
+        index: element
+        for index, element in enumerate(flattened_item)
+        if isinstance(element, str) and os.path.exists(element)
+    }
+
+    if len(indexed_paths) == 0:
+        raise ValueError(f"The provided item {inputs[0]} didn't contain any filepaths.")
+
+    absolute_path = str(Path(indexed_paths[0]).resolve())
+
+    if indexed_paths[0] != absolute_path:
+        raise ValueError("The provided path should be absolute.")
+
+    return "/" + os.path.join(*str(absolute_path).split("/")[:4])
+
+
+class LambdaDataTransformRecipe(DataTransformRecipe):
+    def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
+        super().__init__()
+        self._fn = fn
+        self._inputs = inputs
+
+    def prepare_structure(self, input_dir: Optional[str]) -> Any:
+        return self._inputs
+
+    def prepare_item(self, output_dir: str, item_metadata: Any) -> None:  # type: ignore
+        self._fn(output_dir, item_metadata)
+
+
+class LambdaDataChunkRecipe(DataChunkRecipe):
+    def __init__(
+        self,
+        fn: Callable[[Any], None],
+        inputs: Sequence[Any],
+        chunk_size: Optional[int],
+        chunk_bytes: Optional[int],
+        compression: Optional[str],
+    ):
+        super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
+        self._fn = fn
+        self._inputs = inputs
+
+    def prepare_structure(self, input_dir: Optional[str]) -> Any:
+        return self._inputs
+
+    def prepare_item(self, item_metadata: Any) -> Any:  # type: ignore
+        if isinstance(self._fn, GeneratorType):
+            yield from self._fn(item_metadata)
+        else:
+            yield self._fn(item_metadata)
+
+
+def map(
+    fn: Callable[[str, Any], None],
+    inputs: Sequence[Any],
+    output_dir: str,
+    num_workers: Optional[int] = None,
+    fast_dev_run: Union[bool, int] = False,
+    num_nodes: Optional[int] = None,
+    machine: Optional[str] = None,
+    input_dir: Optional[str] = None,
+) -> None:
+    """This function maps a callable over a collection of files, possibly in a distributed way.
+
+    Arguments:
+        fn: A function to be executed over each input element.
+        inputs: A sequence of inputs to be processed by the `fn` function.
+            Each input should contain at least a valid filepath.
+        output_dir: The folder where the processed data should be stored.
+        num_workers: The number of workers to use during processing.
+        fast_dev_run: Whether to process only a subset of the inputs.
+        num_nodes: When doing remote execution, the number of nodes to use.
+        machine: When doing remote execution, the machine to use.
+
+    """
+    if not isinstance(inputs, Sequence):
+        raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
+
+    if len(inputs) == 0:
+        raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
+
+    if num_nodes is None or int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 0)) > 0:
+        remote_output_dir = _LightningSrcResolver()(output_dir)
+
+        if remote_output_dir is None or "cloudspaces" in remote_output_dir:
+            raise ValueError(
+                f"The provided `output_dir` isn't valid. Found {output_dir}."
+                " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
+            )
+
+        data_processor = DataProcessor(
+            num_workers=num_workers or os.cpu_count(),
+            remote_output_dir=PrettyDirectory(output_dir, remote_output_dir),
+            fast_dev_run=fast_dev_run,
+            version=None,
+            input_dir=input_dir or _get_input_dir(inputs),
+        )
+        return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
+    return _execute(
+        f"data-prep-map-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
+        num_nodes,
+        machine,
+    )
+
+
+def optimize(
+    fn: Callable[[Any], Any],
+    inputs: Sequence[Any],
+    output_dir: str,
+    chunk_size: Optional[int] = None,
+    chunk_bytes: Optional[int] = None,
+    compression: Optional[str] = None,
+    name: Optional[str] = None,
+    num_workers: Optional[int] = None,
+    fast_dev_run: bool = False,
+    num_nodes: Optional[int] = None,
+    machine: Optional[str] = None,
+    input_dir: Optional[str] = None,
+) -> None:
+    """This function converts a dataset into chunks, possibly in a distributed way.
+
+    Arguments:
+        fn: A function to be executed over each input element.
+        inputs: A sequence of inputs to be processed by the `fn` function.
+            Each input should contain at least a valid filepath.
+        output_dir: The folder where the processed data should be stored.
+        chunk_size: The maximum number of elements to hold within a chunk.
+        chunk_bytes: The maximum number of bytes to hold within a chunk.
+        compression: The compression algorithm to use over the chunks.
+        num_workers: The number of workers to use during processing.
+        fast_dev_run: Whether to process only a subset of the inputs.
+        num_nodes: When doing remote execution, the number of nodes to use.
+        machine: When doing remote execution, the machine to use.
+
+    """
+    if not isinstance(inputs, Sequence):
+        raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
+
+    if len(inputs) == 0:
+        raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
+
+    if chunk_size is None and chunk_bytes is None:
+        raise ValueError("Either `chunk_size` or `chunk_bytes` needs to be defined.")
+
+    if num_nodes is None or int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 0)) > 0:
+        remote_output_dir = _LightningSrcResolver()(output_dir)
+
+        if remote_output_dir is None or "cloudspaces" in remote_output_dir:
+            raise ValueError(
+                f"The provided `output_dir` isn't valid. Found {output_dir}."
+                " HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
+            )
+
+        data_processor = DataProcessor(
+            name=name,
+            num_workers=num_workers or os.cpu_count(),
+            remote_output_dir=PrettyDirectory(output_dir, remote_output_dir),
+            fast_dev_run=fast_dev_run,
+            input_dir=input_dir or _get_input_dir(inputs),
+        )
+        return data_processor.run(
+            LambdaDataChunkRecipe(
+                fn,
+                inputs,
+                chunk_size=chunk_size,
+                chunk_bytes=chunk_bytes,
+                compression=compression,
+            )
+        )
+    return _execute(
+        f"data-prep-optimize-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
+        num_nodes,
+        machine,
+    )
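
To make the new entry points concrete, here is a hedged usage sketch modeled on the signatures above and the tests further down in this diff. The `resize_fn`/`load_fn` helpers and the directory names are illustrative placeholders, not part of the commit; only `/teamspace/...`-style paths are what the resolver hint suggests.

import os

from lightning.data.streaming.functions import map, optimize


# Hypothetical helpers, for illustration only.
def resize_fn(output_dir, filepath):
    # `map` calls the function with the output directory plus one input element.
    from PIL import Image

    Image.open(filepath).resize((224, 224)).save(os.path.join(output_dir, os.path.basename(filepath)))


def load_fn(filepath):
    # `optimize` expects the function to return (or yield) the sample to be chunked.
    from PIL import Image

    return [Image.open(filepath), os.path.basename(filepath)]


# Hypothetical input location.
raw_dir = "/teamspace/datasets/raw_images"
inputs = [os.path.join(raw_dir, f) for f in os.listdir(raw_dir)]

# Transform every file and write the results into output_dir.
map(resize_fn, inputs, output_dir="/teamspace/datasets/resized_images", num_workers=4)

# Convert the same files into streaming-ready chunks.
optimize(load_fn, inputs, output_dir="/teamspace/datasets/optimized_images", chunk_bytes=1 << 26, num_workers=4)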

@@ -70,6 +70,9 @@ class PyTreeLoader(BaseItemLoader):
         if chunk_filepath not in self._chunk_filepaths:
             while not os.path.exists(chunk_filepath):
                 sleep(0.001)
+
+            # Wait to avoid any corruption when the file appears
+            sleep(0.001)
             self._chunk_filepaths[chunk_filepath] = True

         with open(chunk_filepath, "rb", 0) as fp:

@@ -206,7 +206,8 @@ class BinaryWriter:

         if len(items) == 0:
             raise RuntimeError(
-                f"The items shouldn't have an empty length. Something went wrong. Found {self._serialized_items}."
+                "The items shouldn't have an empty length. Something went wrong."
+                f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}."
             )

         sizes = list(map(len, items))

@@ -414,3 +415,15 @@ class BinaryWriter:
             return f1 != f2

         return any(is_non_valid(f1, f2) for f1, f2 in zip(data_format_1, data_format_2))
+
+    def _pretty_serialized_items(self) -> Dict[int, Item]:
+        out = {}
+        for key, value in self._serialized_items.items():
+            # drop `data` as it would make logs unreadable.
+            out[key] = Item(
+                index=value.index,
+                bytes=value.bytes,
+                dim=value.dim,
+                data=b"",
+            )
+        return out

@@ -8,6 +8,8 @@ import pytest
 import torch
 from lightning import seed_everything
+from lightning.data.streaming import data_processor as data_processor_module
+from lightning.data.streaming import functions
 from lightning.data.streaming.cache import Cache
 from lightning.data.streaming.data_processor import (
     DataChunkRecipe,
     DataProcessor,

@@ -18,7 +20,7 @@ from lightning.data.streaming.data_processor import (
     _upload_fn,
     _wait_for_file_to_exist,
 )
-from lightning.data.streaming.map import map
+from lightning.data.streaming.functions import map, optimize
 from lightning_utilities.core.imports import RequirementCache

 _PIL_AVAILABLE = RequirementCache("PIL")

@@ -519,7 +521,7 @@ def test_data_process_transform(monkeypatch, tmpdir):
     assert img.size == (12, 12)


-def fn(output_dir, filepath):
+def map_fn(output_dir, filepath):
     from PIL import Image

     img = Image.open(filepath)

@@ -528,31 +530,84 @@ def fn(output_dir, filepath):
     img.save(os.path.join(output_dir, os.path.basename(filepath)))


 @pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
 def test_data_processing_map(monkeypatch, tmpdir):
     from PIL import Image

+    input_dir = os.path.join(tmpdir, "input_dir")
+    os.makedirs(input_dir, exist_ok=True)
     imgs = []
     for i in range(5):
         np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32)
         img = Image.fromarray(np_data).convert("L")
         imgs.append(img)
-        img.save(os.path.join(tmpdir, f"{i}.JPEG"))
+        img.save(os.path.join(input_dir, f"{i}.JPEG"))

     home_dir = os.path.join(tmpdir, "home")
     cache_dir = os.path.join(tmpdir, "cache")
-    remote_output_dir = os.path.join(tmpdir, "target_dir")
-    os.makedirs(remote_output_dir, exist_ok=True)
+    output_dir = os.path.join(tmpdir, "target_dir")
+    os.makedirs(output_dir, exist_ok=True)
     monkeypatch.setenv("DATA_OPTIMIZER_HOME_FOLDER", home_dir)
     monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)

-    inputs = [os.path.join(tmpdir, filename) for filename in os.listdir(tmpdir)]
+    resolver = mock.MagicMock()
+    resolver.return_value = lambda x: x
+    monkeypatch.setattr(functions, "_LightningSrcResolver", resolver)
+    monkeypatch.setattr(data_processor_module, "_LightningSrcResolver", resolver)
+    monkeypatch.setattr(data_processor_module, "_LightningTargetResolver", resolver)
+
+    inputs = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)]
     inputs = [filepath for filepath in inputs if os.path.isfile(filepath)]

-    map(fn, inputs, num_workers=1, remote_output_dir=remote_output_dir)
+    map(map_fn, inputs, num_workers=1, output_dir=output_dir, input_dir=input_dir)

-    assert sorted(os.listdir(remote_output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"]
+    assert sorted(os.listdir(output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"]

     from PIL import Image

-    img = Image.open(os.path.join(remote_output_dir, "0.JPEG"))
+    img = Image.open(os.path.join(output_dir, "0.JPEG"))
     assert img.size == (12, 12)


+def optimize_fn(filepath):
+    print(filepath)
+    from PIL import Image
+
+    return [Image.open(filepath), os.path.basename(filepath)]
+
+
+@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
+def test_data_processing_optimize(monkeypatch, tmpdir):
+    from PIL import Image
+
+    input_dir = os.path.join(tmpdir, "input_dir")
+    os.makedirs(input_dir, exist_ok=True)
+    imgs = []
+    for i in range(5):
+        np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32)
+        img = Image.fromarray(np_data).convert("L")
+        imgs.append(img)
+        img.save(os.path.join(input_dir, f"{i}.JPEG"))
+
+    home_dir = os.path.join(tmpdir, "home")
+    cache_dir = os.path.join(tmpdir, "cache")
+    output_dir = os.path.join(tmpdir, "target_dir")
+    os.makedirs(output_dir, exist_ok=True)
+    monkeypatch.setenv("DATA_OPTIMIZER_HOME_FOLDER", home_dir)
+    monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
+
+    inputs = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)]
+    inputs = [filepath for filepath in inputs if os.path.isfile(filepath)]
+
+    resolver = mock.MagicMock()
+    resolver.return_value = lambda x: x
+    monkeypatch.setattr(functions, "_LightningSrcResolver", resolver)
+    monkeypatch.setattr(data_processor_module, "_LightningSrcResolver", resolver)
+    monkeypatch.setattr(data_processor_module, "_LightningTargetResolver", resolver)
+
+    optimize(optimize_fn, inputs, num_workers=1, output_dir=output_dir, chunk_size=2, input_dir=input_dir)
+
+    assert sorted(os.listdir(output_dir)) == ["chunk-0-0.bin", "chunk-0-1.bin", "chunk-0-2.bin", "index.json"]
+
+    cache = Cache(output_dir, chunk_size=1)
+    assert len(cache) == 5