Improve map and chunkify (#18901)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
thomas chaton 2023-11-01 09:35:35 +00:00 committed by GitHub
parent 31b8777350
commit 85933f355a
9 changed files with 332 additions and 32 deletions

View File

@ -1,4 +1,4 @@
lightning-cloud ==0.5.44 # Must be pinned to ensure compatibility
lightning-cloud ==0.5.46 # Must be pinned to ensure compatibility
packaging
typing-extensions >=4.0.0, <4.8.0
deepdiff >=5.7.0, <6.6.0

View File

@ -1,10 +1,11 @@
from lightning.data.datasets import LightningDataset, LightningIterableDataset
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.map import map
from lightning.data.streaming.functions import map, optimize
__all__ = [
"LightningDataset",
"StreamingDataset",
"LightningIterableDataset",
"map",
"optimize",
]
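
With this change, both entry points are re-exported at the package root; a minimal sketch of the intended import (only the names listed in `__all__` above are assumed):

# The streaming dataset and the new functional entry points share one import path:
from lightning.data import StreamingDataset, map, optimize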

View File

@ -18,7 +18,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming.constants import (
_INDEX_FILENAME,
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42,
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.streaming.item_loader import BaseItemLoader
@ -26,7 +26,7 @@ from lightning.data.streaming.reader import BinaryReader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.writer import BinaryWriter
if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42:
if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
from lightning_cloud.resolver import _find_remote_dir, _try_create_cache_dir
logger = logging.Logger(__name__)

View File

@ -21,7 +21,7 @@ _DEFAULT_FAST_DEV_RUN_ITEMS = 10
# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42 = RequirementCache("lightning-cloud>=0.5.42")
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46 = RequirementCache("lightning-cloud>=0.5.46")
_BOTO3_AVAILABLE = RequirementCache("boto3")
# DON'T CHANGE ORDER

View File

@ -5,11 +5,12 @@ import tempfile
import traceback
import types
from abc import abstractmethod
from dataclasses import dataclass
from multiprocessing import Process, Queue
from queue import Empty
from shutil import copyfile, rmtree
from time import sleep, time
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from urllib import parse
import torch
@ -21,7 +22,7 @@ from lightning.data.streaming.constants import (
_BOTO3_AVAILABLE,
_DEFAULT_FAST_DEV_RUN_ITEMS,
_INDEX_FILENAME,
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42,
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.fabric.accelerators.cuda import is_cuda_available
@ -35,7 +36,7 @@ from lightning.fabric.utilities.distributed import group as _group
if _TORCH_GREATER_EQUAL_2_1_0:
from torch.utils._pytree import tree_flatten, tree_unflatten
if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42:
if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
from lightning_cloud.resolver import _LightningSrcResolver, _LightningTargetResolver
if _BOTO3_AVAILABLE:
@ -160,10 +161,14 @@ def _remove_target(input_dir: str, cache_dir: str, queue_in: Queue) -> None:
# 3. Iterate through the paths and delete them sequentially.
for path in paths:
if input_dir:
cached_filepath = path.replace(input_dir, cache_dir)
if not path.startswith(cache_dir):
path = path.replace(input_dir, cache_dir)
if os.path.exists(cached_filepath):
os.remove(cached_filepath)
if os.path.exists(path):
os.remove(path)
elif os.path.exists(path) and "s3_connections" not in path:
os.remove(path)
def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, remote_output_dir: str) -> None:
@ -387,7 +392,9 @@ class BaseWorker:
}
if len(indexed_paths) == 0:
raise ValueError(f"The provided item {item} didn't contain any filepaths. {flattened_item}")
raise ValueError(
f"The provided item {item} didn't contain any filepaths. The input_dir is {self.input_dir}."
)
paths = []
for index, path in indexed_paths.items():
@ -548,7 +555,7 @@ class DataRecipe:
def _setup(self, name: Optional[str]) -> None:
self._name = name
def _done(self, delete_cached_files: bool, remote_output_dir: str) -> None:
def _done(self, delete_cached_files: bool, remote_output_dir: Any) -> None:
pass
@ -578,7 +585,6 @@ class DataChunkRecipe(DataRecipe):
def _done(self, delete_cached_files: bool, remote_output_dir: str) -> None:
num_nodes = _get_num_nodes()
assert self._name
cache_dir = _get_cache_dir(self._name)
chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")]
@ -647,6 +653,14 @@ class DataTransformRecipe(DataRecipe):
"""Use your item metadata to process your files and save the file outputs into `output_dir`."""
@dataclass
class PrettyDirectory:
"""Holds a directory and its URL."""
directory: str
url: str
class DataProcessor:
def __init__(
self,
@ -656,10 +670,11 @@ class DataProcessor:
num_downloaders: Optional[int] = None,
delete_cached_files: bool = True,
src_resolver: Optional[Callable[[str], Optional[str]]] = None,
fast_dev_run: Optional[bool] = None,
fast_dev_run: Optional[Union[bool, int]] = None,
remote_input_dir: Optional[str] = None,
remote_output_dir: Optional[str] = None,
remote_output_dir: Optional[Union[str, PrettyDirectory]] = None,
random_seed: Optional[int] = 42,
version: Optional[int] = None,
):
"""The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
training faster.
@ -692,18 +707,22 @@ class DataProcessor:
self.remote_input_dir = (
str(remote_input_dir)
if remote_input_dir is not None
else ((self.src_resolver(input_dir) if input_dir else None) if self.src_resolver else None)
else ((self.src_resolver(str(input_dir)) if input_dir else None) if self.src_resolver else None)
)
self.remote_output_dir = (
remote_output_dir
if remote_output_dir is not None
else (self.dst_resolver(name) if self.dst_resolver else None)
else (self.dst_resolver(name, version=version) if self.dst_resolver else None)
)
if self.remote_output_dir:
self.name = self._broadcast_object(self.name)
# Ensure the remote src dir is the same across all ranks
self.remote_output_dir = self._broadcast_object(self.remote_output_dir)
print(f"Storing the files under {self.remote_output_dir}")
if isinstance(self.remote_output_dir, PrettyDirectory):
print(f"Storing the files under {self.remote_output_dir.directory}")
self.remote_output_dir = self.remote_output_dir.url
else:
print(f"Storing the files under {self.remote_output_dir}")
self.random_seed = random_seed
@ -725,7 +744,7 @@ class DataProcessor:
user_items: List[Any] = data_recipe.prepare_structure(self.input_dir)
if not isinstance(user_items, list):
raise ValueError("The setup_fn should return a list of item metadata.")
raise ValueError("The `prepare_structure` should return a list of item metadata.")
# Associate the items to the workers based on num_nodes and node_rank
begins, workers_user_items = _associated_items_to_workers(self.num_workers, user_items)
@ -779,6 +798,8 @@ class DataProcessor:
w.join(0)
print("Workers are finished.")
if self.remote_output_dir:
assert isinstance(self.remote_output_dir, str)
data_recipe._done(self.delete_cached_files, self.remote_output_dir)
print("Finished data processing!")
@ -856,7 +877,7 @@ class DataProcessor:
# Cleanup the cache dir folder to avoid corrupted files from previous run to be there.
if os.path.exists(cache_dir):
rmtree(cache_dir)
rmtree(cache_dir, ignore_errors=True)
os.makedirs(cache_dir, exist_ok=True)
@ -864,7 +885,7 @@ class DataProcessor:
# Cleanup the cache data folder to avoid corrupted files from previous run to be there.
if os.path.exists(cache_data_dir):
rmtree(cache_data_dir)
rmtree(cache_data_dir, ignore_errors=True)
os.makedirs(cache_data_dir, exist_ok=True)

View File

@ -0,0 +1,207 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import datetime
from pathlib import Path
from types import GeneratorType
from typing import Any, Callable, Optional, Sequence, Union
from lightning.data.streaming.constants import _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe, PrettyDirectory
if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
from lightning_cloud.resolver import _execute, _LightningSrcResolver
if _TORCH_GREATER_EQUAL_2_1_0:
from torch.utils._pytree import tree_flatten
def _get_input_dir(inputs: Sequence[Any]) -> str:
flattened_item, _ = tree_flatten(inputs[0])
indexed_paths = {
index: element
for index, element in enumerate(flattened_item)
if isinstance(element, str) and os.path.exists(element)
}
if len(indexed_paths) == 0:
raise ValueError(f"The provided item {inputs[0]} didn't contain any filepaths.")
absolute_path = str(Path(indexed_paths[0]).resolve())
if indexed_paths[0] != absolute_path:
raise ValueError("The provided path should be absolute.")
return "/" + os.path.join(*str(absolute_path).split("/")[:4])
class LambdaDataTransformRecipe(DataTransformRecipe):
def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
super().__init__()
self._fn = fn
self._inputs = inputs
def prepare_structure(self, input_dir: Optional[str]) -> Any:
return self._inputs
def prepare_item(self, output_dir: str, item_metadata: Any) -> None: # type: ignore
self._fn(output_dir, item_metadata)
class LambdaDataChunkRecipe(DataChunkRecipe):
def __init__(
self,
fn: Callable[[Any], None],
inputs: Sequence[Any],
chunk_size: Optional[int],
chunk_bytes: Optional[int],
compression: Optional[str],
):
super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
self._fn = fn
self._inputs = inputs
def prepare_structure(self, input_dir: Optional[str]) -> Any:
return self._inputs
def prepare_item(self, item_metadata: Any) -> Any: # type: ignore
if isinstance(self._fn, GeneratorType):
yield from self._fn(item_metadata)
else:
yield self._fn(item_metadata)
def map(
fn: Callable[[str, Any], None],
inputs: Sequence[Any],
output_dir: str,
num_workers: Optional[int] = None,
fast_dev_run: Union[bool, int] = False,
num_nodes: Optional[int] = None,
machine: Optional[str] = None,
input_dir: Optional[str] = None,
) -> None:
"""This function map a callbable over a collection of files possibly in a distributed way.
Arguments:
fn: A function to be executed over each input element
inputs: A sequence of input to be processed by the `fn` function.
Each input should contain at least a valid filepath.
output_dir: The folder where the processed data should be stored.
num_workers: The number of workers to use during processing
fast_dev_run: Whether to use process only a sub part of the inputs
num_nodes: When doing remote execution, the number of nodes to use.
machine: When doing remote execution, the machine to use.
"""
if not isinstance(inputs, Sequence):
raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
if len(inputs) == 0:
raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
if num_nodes is None or int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 0)) > 0:
remote_output_dir = _LightningSrcResolver()(output_dir)
if remote_output_dir is None or "cloudspaces" in remote_output_dir:
raise ValueError(
f"The provided `output_dir` isn't valid. Found {output_dir}."
" HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
)
data_processor = DataProcessor(
num_workers=num_workers or os.cpu_count(),
remote_output_dir=PrettyDirectory(output_dir, remote_output_dir),
fast_dev_run=fast_dev_run,
version=None,
input_dir=input_dir or _get_input_dir(inputs),
)
return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
return _execute(
f"data-prep-map-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
num_nodes,
machine,
)
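
A minimal usage sketch of `map`, assuming a Lightning Studio where `/teamspace/...` paths resolve; the resize logic, folder names, and worker count are illustrative only:

import os

from PIL import Image

from lightning.data import map


def resize(output_dir: str, filepath: str) -> None:
    # Each worker receives the output folder plus one input filepath.
    img = Image.open(filepath).resize((224, 224))
    img.save(os.path.join(output_dir, os.path.basename(filepath)))


if __name__ == "__main__":
    src = "/teamspace/datasets/raw_images"  # hypothetical input folder
    inputs = [os.path.join(src, name) for name in os.listdir(src)]
    map(resize, inputs, output_dir="/teamspace/datasets/resized_images", num_workers=4)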
def optimize(
fn: Callable[[Any], Any],
inputs: Sequence[Any],
output_dir: str,
chunk_size: Optional[int] = None,
chunk_bytes: Optional[int] = None,
compression: Optional[str] = None,
name: Optional[str] = None,
num_workers: Optional[int] = None,
fast_dev_run: bool = False,
num_nodes: Optional[int] = None,
machine: Optional[str] = None,
input_dir: Optional[str] = None,
) -> None:
"""This function converts a dataset into chunks possibly in a distributed way.
Arguments:
fn: A function to be executed over each input element
inputs: A sequence of input to be processed by the `fn` function.
Each input should contain at least a valid filepath.
output_dir: The folder where the processed data should be stored.
chunk_size: The maximum number of elements to hold within a chunk.
chunk_bytes: The maximum number of bytes to hold within a chunk.
compression: The compression algorithm to use over the chunks.
num_workers: The number of workers to use during processing
fast_dev_run: Whether to use process only a sub part of the inputs
num_nodes: When doing remote execution, the number of nodes to use.
machine: When doing remote execution, the machine to use.
"""
if not isinstance(inputs, Sequence):
raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
if len(inputs) == 0:
raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
if chunk_size is None and chunk_bytes is None:
raise ValueError("Either `chunk_size` or `chunk_bytes` needs to be defined.")
if num_nodes is None or int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 0)) > 0:
remote_output_dir = _LightningSrcResolver()(output_dir)
if remote_output_dir is None or "cloudspaces" in remote_output_dir:
raise ValueError(
f"The provided `output_dir` isn't valid. Found {output_dir}."
" HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
)
data_processor = DataProcessor(
name=name,
num_workers=num_workers or os.cpu_count(),
remote_output_dir=PrettyDirectory(output_dir, remote_output_dir),
fast_dev_run=fast_dev_run,
input_dir=input_dir or _get_input_dir(inputs),
)
return data_processor.run(
LambdaDataChunkRecipe(
fn,
inputs,
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,
)
)
return _execute(
f"data-prep-optimize-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
num_nodes,
machine,
)
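
Similarly, a minimal usage sketch of `optimize`, again assuming resolvable `/teamspace/...` paths; the sample layout, chunk size, and folder names are illustrative only:

import os

from PIL import Image

from lightning.data import optimize


def to_sample(filepath: str):
    # One sample per input file: the decoded image plus its original filename.
    return [Image.open(filepath), os.path.basename(filepath)]


if __name__ == "__main__":
    src = "/teamspace/datasets/raw_images"  # hypothetical input folder
    inputs = [os.path.join(src, name) for name in os.listdir(src)]
    optimize(to_sample, inputs, output_dir="/teamspace/datasets/chunked_images", chunk_size=64, num_workers=4)

The resulting `chunk-*.bin` files and `index.json` can then be read back through the streaming `Cache`, as the new `test_data_processing_optimize` test below does.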

View File

@ -70,6 +70,9 @@ class PyTreeLoader(BaseItemLoader):
if chunk_filepath not in self._chunk_filepaths:
while not os.path.exists(chunk_filepath):
sleep(0.001)
# Wait to avoid any corruption when the file appears
sleep(0.001)
self._chunk_filepaths[chunk_filepath] = True
with open(chunk_filepath, "rb", 0) as fp:

View File

@ -206,7 +206,8 @@ class BinaryWriter:
if len(items) == 0:
raise RuntimeError(
f"The items shouldn't have an empty length. Something went wrong. Found {self._serialized_items}."
"The items shouldn't have an empty length. Something went wrong."
f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}."
)
sizes = list(map(len, items))
@ -414,3 +415,15 @@ class BinaryWriter:
return f1 != f2
return any(is_non_valid(f1, f2) for f1, f2 in zip(data_format_1, data_format_2))
def _pretty_serialized_items(self) -> Dict[int, Item]:
out = {}
for key, value in self._serialized_items.items():
# drop `data` as it would make logs unreadable.
out[key] = Item(
index=value.index,
bytes=value.bytes,
dim=value.dim,
data=b"",
)
return out

View File

@ -8,6 +8,8 @@ import pytest
import torch
from lightning import seed_everything
from lightning.data.streaming import data_processor as data_processor_module
from lightning.data.streaming import functions
from lightning.data.streaming.cache import Cache
from lightning.data.streaming.data_processor import (
DataChunkRecipe,
DataProcessor,
@ -18,7 +20,7 @@ from lightning.data.streaming.data_processor import (
_upload_fn,
_wait_for_file_to_exist,
)
from lightning.data.streaming.map import map
from lightning.data.streaming.functions import map, optimize
from lightning_utilities.core.imports import RequirementCache
_PIL_AVAILABLE = RequirementCache("PIL")
@ -519,7 +521,7 @@ def test_data_process_transform(monkeypatch, tmpdir):
assert img.size == (12, 12)
def fn(output_dir, filepath):
def map_fn(output_dir, filepath):
from PIL import Image
img = Image.open(filepath)
@ -528,31 +530,84 @@ def fn(output_dir, filepath):
img.save(os.path.join(output_dir, os.path.basename(filepath)))
@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
def test_data_processing_map(monkeypatch, tmpdir):
from PIL import Image
input_dir = os.path.join(tmpdir, "input_dir")
os.makedirs(input_dir, exist_ok=True)
imgs = []
for i in range(5):
np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32)
img = Image.fromarray(np_data).convert("L")
imgs.append(img)
img.save(os.path.join(tmpdir, f"{i}.JPEG"))
img.save(os.path.join(input_dir, f"{i}.JPEG"))
home_dir = os.path.join(tmpdir, "home")
cache_dir = os.path.join(tmpdir, "cache")
remote_output_dir = os.path.join(tmpdir, "target_dir")
os.makedirs(remote_output_dir, exist_ok=True)
output_dir = os.path.join(tmpdir, "target_dir")
os.makedirs(output_dir, exist_ok=True)
monkeypatch.setenv("DATA_OPTIMIZER_HOME_FOLDER", home_dir)
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
inputs = [os.path.join(tmpdir, filename) for filename in os.listdir(tmpdir)]
resolver = mock.MagicMock()
resolver.return_value = lambda x: x
monkeypatch.setattr(functions, "_LightningSrcResolver", resolver)
monkeypatch.setattr(data_processor_module, "_LightningSrcResolver", resolver)
monkeypatch.setattr(data_processor_module, "_LightningTargetResolver", resolver)
inputs = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)]
inputs = [filepath for filepath in inputs if os.path.isfile(filepath)]
map(fn, inputs, num_workers=1, remote_output_dir=remote_output_dir)
map(map_fn, inputs, num_workers=1, output_dir=output_dir, input_dir=input_dir)
assert sorted(os.listdir(remote_output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"]
assert sorted(os.listdir(output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"]
from PIL import Image
img = Image.open(os.path.join(remote_output_dir, "0.JPEG"))
img = Image.open(os.path.join(output_dir, "0.JPEG"))
assert img.size == (12, 12)
def optimize_fn(filepath):
print(filepath)
from PIL import Image
return [Image.open(filepath), os.path.basename(filepath)]
@pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
def test_data_processing_optimize(monkeypatch, tmpdir):
from PIL import Image
input_dir = os.path.join(tmpdir, "input_dir")
os.makedirs(input_dir, exist_ok=True)
imgs = []
for i in range(5):
np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32)
img = Image.fromarray(np_data).convert("L")
imgs.append(img)
img.save(os.path.join(input_dir, f"{i}.JPEG"))
home_dir = os.path.join(tmpdir, "home")
cache_dir = os.path.join(tmpdir, "cache")
output_dir = os.path.join(tmpdir, "target_dir")
os.makedirs(output_dir, exist_ok=True)
monkeypatch.setenv("DATA_OPTIMIZER_HOME_FOLDER", home_dir)
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
inputs = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)]
inputs = [filepath for filepath in inputs if os.path.isfile(filepath)]
resolver = mock.MagicMock()
resolver.return_value = lambda x: x
monkeypatch.setattr(functions, "_LightningSrcResolver", resolver)
monkeypatch.setattr(data_processor_module, "_LightningSrcResolver", resolver)
monkeypatch.setattr(data_processor_module, "_LightningTargetResolver", resolver)
optimize(optimize_fn, inputs, num_workers=1, output_dir=output_dir, chunk_size=2, input_dir=input_dir)
assert sorted(os.listdir(output_dir)) == ["chunk-0-0.bin", "chunk-0-1.bin", "chunk-0-2.bin", "index.json"]
cache = Cache(output_dir, chunk_size=1)
assert len(cache) == 5