Introduce Cache 1/n (#18642)

Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
thomas chaton 2023-10-09 16:06:32 +01:00 committed by GitHub
parent 6537a05977
commit 1d5851ffe2
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 2267 additions and 20 deletions


@@ -173,9 +173,9 @@ subprojects:
- "!*.md"
- "!**/*.md"
checks:
- "data-cpu (macOS-11, lightning, 3.10, 2.0)"
- "data-cpu (ubuntu-20.04, lightning, 3.10, 2.0)"
- "data-cpu (windows-2022, lightning, 3.10, 2.0)"
- "data-cpu (macOS-11, lightning, 3.10, 2.1)"
- "data-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
- "data-cpu (windows-2022, lightning, 3.10, 2.1)"
# SECTION: lightning_fabric


@@ -34,9 +34,9 @@ jobs:
fail-fast: false
matrix:
include:
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
# "oldest" versions tests, only on minimum Python
# - {os: "macOS-11", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"}
# - {os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"}


@@ -3,6 +3,6 @@
lightning-utilities >=0.8.0, <0.10.0
# to be able to include also 0.6 and preserve `>` needed for CI min version bypass
torchdata >0.5.9, <0.7.0
torchdata >0.5.9, <=0.7.0
# to be able to include also PL 2.0 and preserve `>` needed for CI min version bypass
torch >0.14.0, <2.1.0
torch >0.14.0, <=2.1.0

src/lightning/data/cache/__init__.py (new vendored file, 17 lines)

@@ -0,0 +1,17 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning.data.cache.cache import Cache
from lightning.data.cache.dataloader import LightningDataLoader
__all__ = ["Cache", "LightningDataLoader"]

src/lightning/data/cache/cache.py (new vendored file, 90 lines)

@@ -0,0 +1,90 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.reader import BinaryReader
from lightning.data.cache.sampler import ChunkedIndex
from lightning.data.cache.writer import BinaryWriter
from lightning.data.datasets.env import _DistributedEnv
logger = logging.Logger(__name__)
class Cache:
def __init__(
self,
cache_dir: str,
remote_dir: Optional[str] = None,
compression: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_bytes: Optional[int] = None,
):
"""The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
together in order to accelerate fetching.
Arguments:
cache_dir: The path to where the chunks will be stored.
remote_dir: The path to a remote folder where the data are located.
The scheme needs to be added to the path.
compression: The name of the algorithm to reduce the size of the chunks.
chunk_bytes: The maximum number of bytes within a chunk.
chunk_size: The maximum number of items within a chunk.
"""
super().__init__()
if not _TORCH_2_1_0_AVAILABLE:
raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")
self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression)
self._cache_dir = cache_dir
self._is_done = False
self._distributed_env = _DistributedEnv.detect()
@property
def filled(self) -> bool:
"""Returns whether the caching phase is done."""
if self._is_done:
return True
self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
return self._is_done
def __setitem__(self, index: int, data: Any) -> None:
"""Store an item in the writer."""
self._writer[index] = data
def __getitem__(self, index: Union[int, ChunkedIndex]) -> Dict[str, Any]:
"""Read an item in the reader."""
if isinstance(index, int):
index = ChunkedIndex(index, self._get_chunk_index_from_index(index))
return self._reader.read(index)
def done(self) -> None:
"""Inform the writer the chunking phase is finished."""
self._writer.done()
def merge(self, num_workers: int = 1) -> None:
"""Inform the writer the chunking phase is finished."""
self._writer.merge(num_workers)
def __len__(self) -> int:
return self._reader.get_length()
def get_chunk_interval(self) -> List[Tuple[int, int]]:
return self._reader.get_chunk_interval()
def _get_chunk_index_from_index(self, index: int) -> int:
return self._reader._get_chunk_index_from_index(index)
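A minimal sketch of how the Cache above can be embedded into a map-style dataset, assuming torch >= 2.1 and an existing, writable cache directory (the dataset, path, and chunk size are illustrative):

from torch.utils.data import Dataset
from lightning.data.cache import Cache

class CachedDataset(Dataset):
    def __init__(self, samples, cache_dir: str):
        self.samples = samples
        # Group 100 items per chunk; `chunk_bytes` could be used instead.
        self.cache = Cache(cache_dir, chunk_size=100)

    def __len__(self):
        return len(self.cache) if self.cache.filled else len(self.samples)

    def __getitem__(self, index):
        if self.cache.filled:
            # Subsequent epochs read from the optimised chunks.
            return self.cache[index]
        # First epoch: write the item into the cache and return it.
        # The writer expects ordered indexes; the LightningDataLoader's sampler guarantees this.
        self.cache[index] = self.samples[index]
        return self.samples[index]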

src/lightning/data/cache/compression.py (new vendored file, 76 lines)

@@ -0,0 +1,76 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractclassmethod, abstractmethod
from typing import Dict, TypeVar
from lightning_utilities.core.imports import RequirementCache, requires
_ZSTD_AVAILABLE = RequirementCache("zstd")
if _ZSTD_AVAILABLE:
import zstd
TCompressor = TypeVar("TCompressor", bound="Compressor")
class Compressor(ABC):
"""Base class for compression algorithm."""
@abstractmethod
def compress(self, data: bytes) -> bytes:
pass
@abstractmethod
def decompress(self, data: bytes) -> bytes:
pass
@abstractclassmethod
def register(cls, compressors: Dict[str, "Compressor"]) -> None:
pass
class ZSTDCompressor(Compressor):
"""Compressor for the zstd package."""
@requires("zstd")
def __init__(self, level: int) -> None:
super().__init__()
self.level = level
self.extension = "zstd"
@property
def name(self) -> str:
return f"{self.extension}:{self.level}"
def compress(self, data: bytes) -> bytes:
return zstd.compress(data, self.level)
def decompress(self, data: bytes) -> bytes:
return zstd.decompress(data)
@classmethod
def register(cls, compressors: Dict[str, "Compressor"]) -> None: # type: ignore
if not _ZSTD_AVAILABLE:
return
# default
compressors["zstd"] = ZSTDCompressor(4)
for level in list(range(1, 23)):
compressors[f"zstd:{level}"] = ZSTDCompressor(level)
_COMPRESSORS: Dict[str, Compressor] = {}
ZSTDCompressor.register(_COMPRESSORS)
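A short sketch of using the compressor registry above directly, assuming the optional zstd package is installed (otherwise the registry stays empty):

from lightning.data.cache.compression import _COMPRESSORS

compressor = _COMPRESSORS["zstd:3"]  # zstd at compression level 3
compressed = compressor.compress(b"some raw chunk bytes")
assert compressor.decompress(compressed) == b"some raw chunk bytes"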

src/lightning/data/cache/config.py (new vendored file, 125 lines)

@@ -0,0 +1,125 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.downloader import get_downloader_cls
from lightning.data.cache.sampler import ChunkedIndex
if _TORCH_2_1_0_AVAILABLE:
from torch.utils._pytree import treespec_loads
class ChunksConfig:
def __init__(self, cache_dir: str, remote_dir: Optional[str]):
"""The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its
chunk.
Arguments:
cache_dir: The path to cache folder.
remote_dir: The path to a remote folder where the data are located.
The scheme needs to be added to the path.
"""
self._cache_dir = cache_dir
self._intervals: List[Tuple[int, int]] = []
self._config = None
self._chunks = []
self._remote_dir = remote_dir
with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
data = json.load(f)
self._config = data["config"]
self._chunks.extend(data["chunks"])
self._config["data_spec"] = treespec_loads(self._config["data_spec"])
for chunk in self._chunks:
start, end = chunk["interval"]
if (end - start) != chunk["chunk_size"]:
raise Exception(
"The config intervals doesn't match the number of samples. This shouldn't have happened."
)
self._intervals.append((chunk["interval"][0], chunk["interval"][1]))
self._length = sum([chunk["chunk_size"] for chunk in self._chunks])
self._downloader = None
if remote_dir:
self._downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, self._chunks)
def download_chunk_from_index(self, chunk_index: int) -> None:
chunk_filename = self._chunks[chunk_index]["filename"]
local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
if os.path.exists(local_chunkpath):
return
if self._downloader is None:
raise RuntimeError("The downloader should be defined.")
self._downloader.download_chunk_from_index(chunk_index)
@property
def intervals(self) -> List[Tuple[int, int]]:
if self._intervals is None:
raise RuntimeError("The intervals should be defined.")
return self._intervals
@property
def data_format(self) -> Any:
if self._config is None:
raise RuntimeError("The config should be defined.")
return self._config["data_format"]
@property
def config(self) -> Dict[str, Any]:
if self._config is None:
raise RuntimeError("The config should be defined.")
return self._config
def _get_chunk_index_from_index(self, index: int) -> int:
for chunk_index, internal in enumerate(self._intervals):
if internal[0] <= index < internal[1]:
return chunk_index
raise ValueError(
f"The provided index {index} didn't find a match within the chunk intervals {self._intervals}."
)
def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
"""Find the associated chunk metadata."""
chunk = self._chunks[index.chunk_index]
return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]
@classmethod
def load(cls, cache_dir: str, remote_dir: Optional[str] = None) -> Optional["ChunksConfig"]:
cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)
if isinstance(remote_dir, str):
downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, [])
downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)
if not os.path.exists(cache_index_filepath):
return None
return ChunksConfig(cache_dir, remote_dir)
def __len__(self) -> int:
return self._length
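A sketch of how the reader uses the ChunksConfig above to resolve a sample index into a chunk file and byte interval, assuming an index.json already exists in the cache directory (the path and index are illustrative):

from lightning.data.cache.config import ChunksConfig
from lightning.data.cache.sampler import ChunkedIndex

config = ChunksConfig.load("/path/to/cache")  # returns None while no index.json exists
if config is not None:
    chunk_index = config._get_chunk_index_from_index(42)
    chunk_filepath, begin, end = config[ChunkedIndex(42, chunk_index)]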

src/lightning/data/cache/constants.py (new vendored file, 21 lines)

@@ -0,0 +1,21 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_utilities.core.imports import RequirementCache
_INDEX_FILENAME = "index.json"
_DEFAULT_CHUNK_BYTES = 1 << 26  # 64 MB
# This is required for full pytree serialization / deserialization support
_TORCH_2_1_0_AVAILABLE = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")

src/lightning/data/cache/dataloader.py (new vendored file, 311 lines)

@@ -0,0 +1,311 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import inspect
import logging
import os
from importlib import reload
from typing import Any, Callable, List, Optional
import torch
from torch.utils.data import Dataset, IterableDataset
from torch.utils.data._utils.collate import default_collate
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
from torch.utils.data.dataloader import (
DataLoader,
_BaseDataLoaderIter,
_DatasetKind,
_MultiProcessingDataLoaderIter,
_SingleProcessDataLoaderIter,
)
from torch.utils.data.sampler import BatchSampler, Sampler
from lightning.data.cache import Cache
from lightning.data.cache.constants import _DEFAULT_CHUNK_BYTES, _TORCH_2_1_0_AVAILABLE, _VIZ_TRACKER_AVAILABLE
from lightning.data.cache.sampler import CacheBatchSampler
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
if _TORCH_2_1_0_AVAILABLE:
from torch.utils._pytree import tree_flatten
logger = logging.Logger(__name__)
def _equal_items(data_1: Any, data_2: Any) -> bool:
data_1_flattened, _ = tree_flatten(data_1)
data_2_flattened, _ = tree_flatten(data_2)
if len(data_1_flattened) != len(data_2_flattened):
return False
return all(_equal_item(d1, d2) for d1, d2 in zip(data_1_flattened, data_2_flattened))
def _equal_item(d1: Any, d2: Any) -> bool:
if not isinstance(d1, type(d2)):
return False
equality = d1 == d2
if isinstance(equality, torch.Tensor):
return bool(equality.all().item())
if equality is True:
return True
return False
class CacheDataset(Dataset):
def __init__(
self,
dataset: Any,
cache_dir: str,
chunk_bytes: Optional[int],
chunk_size: Optional[int],
compression: Optional[str],
):
"""The `CacheDataset` is a dataset wraper to provide a beginner experience with the Cache.
Arguments:
dataset: The dataset of the user
cache_dir: The folder where the chunks are written to.
chunk_bytes: The maximal number of bytes to write within a chunk.
chunk_size: The maximal number of items to write to a chunk.
compression: The compression algorithm to use to reduce the size of the chunk.
"""
self._dataset = dataset
self._cache = Cache(cache_dir, chunk_bytes=chunk_bytes, chunk_size=chunk_size, compression=compression)
self._is_deterministic = False
def __len__(self) -> int:
return len(self._cache) if self._cache.filled else len(self._dataset)
def __getitem__(self, index: int) -> Any:
data_1 = self._cache[index] if self._cache.filled else self._dataset[index]
if not self._cache.filled:
if not self._is_deterministic:
data2 = self._dataset[index]
if not _equal_items(data_1, data2):
raise ValueError(
f"Your dataset items aren't deterministic. Found {data_1} and {data2} for index {index}."
" HINT: Use the `lightning.data.cache.Cache` directly within your dataset."
)
self._is_deterministic = True
self._cache[index] = data_1
return data_1
class CacheCollateFn:
"""This CacheCollateFn is used to accelerate the processing of the data generated using the Cache.
During the chunking phase, there is no need to return any data from the DataLoader reducing some time.
Additionally, if the user makes their __getitem__ asynchronous, the collate executes them in parallel.
"""
def __init__(self, collate_fn: Optional[Callable] = None) -> None:
self.collate_fn = collate_fn or default_collate
def __call__(self, items: List[Any]) -> Any:
if all(item is None for item in items):
return None
# If the __getitem__ method is asynchronous, collect all the items.
if all(inspect.iscoroutine(item) for item in items):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
items = loop.run_until_complete(asyncio.gather(*items))
return self.collate_fn([item for item in items if item is not None])
class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter):
"""This is overriden to inform the cache is done chunking."""
def _next_data(self) -> Any:
try:
data = None
while data is None:
data = super()._next_data()
return data
except StopIteration:
for v in self._dataset_fetcher.dataset.__dict__.values():
if isinstance(v, Cache):
v.done()
if not v.filled:
v.merge(1)
raise StopIteration()
class WorkerLoop:
"""Wrap the PyTorch DataLoader WorkerLoop to perform caching and profiling."""
def __init__(self, global_rank: int, profile: bool = False) -> None:
self._global_rank = global_rank
self._profile = profile
def __call__(self, dataset_kind: _DatasetKind, *args: Any, **kwargs: Any) -> None:
from torch.utils.data._utils import worker
from lightning.data.cache.cache import Cache
rank = _WorkerEnv.detect().rank
enable_profiling = self._global_rank == 0 and rank == 0 and _VIZ_TRACKER_AVAILABLE and self._profile
if enable_profiling:
from viztracer import VizTracer
tracer = VizTracer(output_file=os.path.join(os.getcwd(), "trace.json"))
tracer.start()
# Reload to remove the patching
reloaded_worker = reload(worker)
create_fetcher = _DatasetKind.create_fetcher
fetcher = None
def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
nonlocal fetcher
fetcher = create_fetcher(*args, **kwargs)
return fetcher
_DatasetKind.create_fetcher = create_fetcher_fn # type: ignore
reloaded_worker._worker_loop(dataset_kind, *args, **kwargs)
if dataset_kind == _DatasetKind.Map:
assert fetcher
for v in fetcher.dataset.__dict__.values():
if isinstance(v, Cache):
v.done()
if enable_profiling:
tracer.stop()
tracer.save()
class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter):
def __init__(self, loader: DataLoader) -> None:
self._cache = loader._cache
self._num_workers = loader.num_workers
# Patch PyTorch worker loop to call the `cache.done()` method.
from torch.utils.data._utils import worker
worker._worker_loop = WorkerLoop(loader._global_rank, loader._profile)
super().__init__(loader)
def _shutdown_workers(self) -> None:
super()._shutdown_workers()
# If the cache isn't filled, we trigger a merge of the per-worker indexes
if not self._cache.filled:
self._cache.merge(self._num_workers)
def _next_data(self) -> Any:
try:
data = None
while data is None:
data = super()._next_data()
return data
except StopIteration as e:
raise e
class LightningDataLoader(DataLoader):
__doc__ = DataLoader.__doc__
def __init__(
self,
dataset: Any,
*args: Any,
sampler: Optional[Sampler] = None,
batch_sampler: Optional[BatchSampler] = None,
num_workers: int = 0,
shuffle: bool = False,
generator: Optional[torch.Generator] = None,
batch_size: Optional[int] = None,
drop_last: bool = False,
cache_dir: Optional[str] = None,
chunk_bytes: Optional[int] = _DEFAULT_CHUNK_BYTES,
compression: Optional[str] = None,
profile: bool = False,
collate_fn: Optional[Callable] = None,
**kwargs: Any,
) -> None:
if sampler:
raise ValueError(
"The LightningDataLoader relies on its own internal sampler. Passing a sampler isn't supported."
)
if batch_sampler:
raise ValueError(
"The LightningDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported."
)
if isinstance(dataset, IterableDataset):
raise ValueError("Only map-based dataset are supported by the LightningDataLoader for now.")
if profile and not _VIZ_TRACKER_AVAILABLE:
raise ModuleNotFoundError("To enable DataLoader profiling, run `pip install viztracer`.")
cache_list = [v for v in dataset.__dict__.values() if isinstance(v, Cache)]
if len(cache_list) > 1:
raise ValueError(
"We found several Cache used as attributes from your dataset. Only one is support for now."
)
if len(cache_list) == 0:
if cache_dir is None:
raise ValueError("You should provide a `cache_dir` filepath to the LightningDataLoader.")
dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size, compression)
cache = dataset._cache
else:
cache = cache_list[0]
if not cache.filled and shuffle:
logger.info("Shuffle is ignored during the caching phase phase.")
self._cache = cache
distributed_env = _DistributedEnv.detect()
self._global_rank = distributed_env.global_rank
batch_sampler = CacheBatchSampler(
len(dataset),
distributed_env.world_size,
self._global_rank,
num_workers,
batch_size or 1,
drop_last,
shuffle,
cache,
)
self._profile = profile
super().__init__(
dataset,
*args,
batch_sampler=batch_sampler, # type: ignore
collate_fn=CacheCollateFn(collate_fn),
num_workers=num_workers,
**kwargs,
)
def _get_iterator(self) -> "_BaseDataLoaderIter":
"""Overriden to ensure the `Cache.done()` method is triggered on iteration done."""
if self.num_workers == 0:
return _SingleProcessDataLoaderIterPatch(self)
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIterPatch(self)
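A minimal sketch of the beginner flow enabled by the LightningDataLoader above: a plain map-style dataset is wrapped into a CacheDataset, chunked during the first pass, and served from the chunks afterwards. It assumes an existing cache directory; the dataset, path, and sizes are illustrative:

from torch.utils.data import Dataset
from lightning.data.cache import LightningDataLoader

class RangeDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, index):
        return {"index": index, "square": index**2}

dataloader = LightningDataLoader(
    RangeDataset(),
    cache_dir="/tmp/cache",  # existing folder where the chunks and index.json are written
    chunk_bytes=None,        # chunk by item count instead; batch_size is reused as the chunk size
    batch_size=4,
    num_workers=2,
)

for _ in dataloader:  # first epoch: the items are serialized into chunks
    pass

for batch in dataloader:  # later epochs: batches are read back from the chunks
    ...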

src/lightning/data/cache/downloader.py (new vendored file, 74 lines)

@@ -0,0 +1,74 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
from urllib import parse
class Downloader(ABC):
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
self._remote_dir = remote_dir
self._cache_dir = cache_dir
self._chunks = chunks
def download_chunk_from_index(self, chunk_index: int) -> None:
chunk_filename = self._chunks[chunk_index]["filename"]
local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
remote_chunkpath = os.path.join(self._remote_dir, chunk_filename)
self.download_file(remote_chunkpath, local_chunkpath)
@abstractmethod
def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
pass
class S3Downloader(Downloader):
@classmethod
def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
import boto3
from boto3.s3.transfer import TransferConfig
from botocore.config import Config
obj = parse.urlparse(remote_filepath)
if obj.scheme != "s3":
raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")
extra_args: Dict[str, Any] = {}
# Create a new session per thread
session = boto3.session.Session()
# Create a resource client using a thread's session object
s3 = session.client("s3", config=Config(read_timeout=None))
# Threads calling S3 operations return RuntimeError (cannot schedule new futures after
# interpreter shutdown). Temporary solution is to have `use_threads` as `False`.
# Issue: https://github.com/boto/boto3/issues/3113
s3.download_file(
obj.netloc,
obj.path.lstrip("/"),
local_filepath,
ExtraArgs=extra_args,
Config=TransferConfig(use_threads=False),
)
# TODO: Add fsspec support
_DOWNLOADERS = {"s3://": S3Downloader}
def get_downloader_cls(remote_dir: str) -> Type[Downloader]:
for k, cls in _DOWNLOADERS.items():
if remote_dir.startswith(k):
return cls
raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.")

src/lightning/data/cache/reader.py (new vendored file, 177 lines)

@@ -0,0 +1,177 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from threading import Lock, Thread
from time import sleep
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from lightning.data.cache.config import ChunksConfig
from lightning.data.cache.constants import _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.sampler import ChunkedIndex
from lightning.data.cache.serializers import _SERIALIZERS, Serializer
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
if _TORCH_2_1_0_AVAILABLE:
from torch.utils._pytree import PyTree, tree_unflatten
class PrepareChunksThread(Thread):
"""This thread is responsible to download the chunks associated to a given worker."""
def __init__(self, config: ChunksConfig) -> None:
super().__init__(daemon=True)
self._config = config
self._chunks_index_to_be_processed: List[int] = []
self._chunks_index_to_ready: List[int] = []
self._lock = Lock()
def add(self, chunk_indices: List[int]) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
with self._lock:
self._chunks_index_to_be_processed.extend(chunk_indices)
def run(self) -> None:
while True:
with self._lock:
if len(self._chunks_index_to_be_processed) == 0:
sleep(0.007)
continue
chunk_index = self._chunks_index_to_be_processed.pop(0)
# TODO: Implement eviction
self._config.download_chunk_from_index(chunk_index)
self._chunks_index_to_ready.append(chunk_index)
class BinaryReader:
def __init__(self, cache_dir: str, remote_dir: Optional[str] = None, compression: Optional[str] = None) -> None:
"""The BinaryReader enables to read chunked dataset in an efficient way.
Arguments:
cache_dir: The path to cache folder.
remote_dir: The path to a remote folder where the data are located.
The scheme needs to be added to the path.
compression: The algorithm to decompress the chunks.
"""
super().__init__()
self._cache_dir = cache_dir
self._remote_dir = remote_dir
if not os.path.exists(self._cache_dir):
raise FileNotFoundError(f"The provided cache_dir `{self._cache_dir}` doesn't exist.")
self._compression = compression
self._intervals: Optional[List[str]] = None
self._serializers: Dict[str, Serializer] = _SERIALIZERS
self._distributed_env = _DistributedEnv.detect()
self._rank: Optional[int] = None
self._config: Optional[ChunksConfig] = None
self._prepare_thread: Optional[PrepareChunksThread] = None
def _get_chunk_index_from_index(self, index: int) -> int:
# Load the config containing the index
if self._config is None and self._try_load_config() is None:
raise Exception("The reader index isn't defined.")
return self._config._get_chunk_index_from_index(index) # type: ignore
def _try_load_config(self) -> Optional[ChunksConfig]:
"""Try to load the chunks config if the index files are available."""
self._config = ChunksConfig.load(self._cache_dir, self._remote_dir)
return self._config
@property
def config(self) -> ChunksConfig:
if self._config is None:
raise RuntimeError("The config should be defined.")
return self._config
@property
def rank(self) -> int:
"""Returns the rank of the writer."""
if self._rank is None:
self._worker_env = _WorkerEnv.detect()
self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank
return self._rank
def read(self, index: ChunkedIndex) -> Any:
"""Read an item for the given from a chunk.
If the chunk isn't available locally or in memory, it will be downloaded.
Prefetching should reduce the wait time to be the batch available.
"""
if not isinstance(index, ChunkedIndex):
raise ValueError("The Reader.read(...) method expects a chunked Index.")
# Load the config containing the index
if self._config is None and self._try_load_config() is None:
raise Exception("The reader index isn't defined.")
# Create and start the prepare chunks thread
if index.chunk_indexes is not None and self._prepare_thread is None and self._config:
self._prepare_thread = PrepareChunksThread(self._config)
self._prepare_thread.start()
self._prepare_thread.add(index.chunk_indexes)
# Fetch the element
chunk_filepath, begin, _ = self.config[index]
raw_item_data = self.load_item_from_chunk(index.index, chunk_filepath, begin)
return self.deserialize(raw_item_data)
def deserialize(self, raw_item_data: bytes) -> "PyTree":
"""Deserialize the raw bytes into their python equivalent."""
idx = len(self.config.data_format) * 4
sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
data = []
for size, data_format in zip(sizes, self.config.data_format):
serializer = self._serializers[data_format]
data_bytes = raw_item_data[idx : idx + size]
data.append(serializer.deserialize(data_bytes))
idx += size
return tree_unflatten(data, self.config.config["data_spec"])
def load_item_from_chunk(self, index: int, chunk_filepath: str, begin: int) -> bytes:
offset = (1 + (index - begin)) * 4
while not os.path.exists(chunk_filepath):
sleep(0.0001)
with open(chunk_filepath, "rb", 0) as fp:
fp.seek(offset)
pair = fp.read(8)
begin, end = np.frombuffer(pair, np.uint32)
fp.seek(begin)
data = fp.read(end - begin)
return data
def get_length(self) -> int:
"""Get the number of samples across all chunks."""
if self._config is None and self._try_load_config() is None:
raise Exception("The reader index isn't defined.")
return len(self.config)
def get_chunk_interval(self) -> List[Tuple[int, int]]:
"""Get the index interval of each chunk."""
if self._config is None and self._try_load_config() is None:
raise Exception("The reader index isn't defined.")
return self.config.intervals
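A sketch of reading a single sample back through the BinaryReader above, assuming the chunks and index.json already exist locally (the path and index are illustrative):

from lightning.data.cache.reader import BinaryReader
from lightning.data.cache.sampler import ChunkedIndex

reader = BinaryReader("/path/to/cache")
index = 5
item = reader.read(ChunkedIndex(index, reader._get_chunk_index_from_index(index)))
print(reader.get_length(), reader.get_chunk_interval())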

src/lightning/data/cache/sampler.py (new vendored file, 226 lines)

@@ -0,0 +1,226 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Union
import numpy as np
logger = logging.Logger(__name__)
@dataclass
class ChunkedIndex:
index: int
chunk_index: int
chunk_indexes: Optional[List[int]] = None
class CacheBatchSampler:
def __init__(
self,
dataset_size: int,
num_replicas: int,
global_rank: int,
num_workers: int,
batch_size: int,
drop_last: bool,
shuffle: bool,
cache: Any,
):
"""The CacheBatchSampler handles the generation of batch indices.
If the cache isn't filled, the batch sampler alternates with ordered indices for the writer to chunk the dataset
If the cache is filled, it acts as normal BatchSampler.
Arguments:
dataset_size: The size of the dataset.
num_replicas: The number of processes involved in the distributed training.
global_rank: The global rank of the given process.
num_workers: The number of workers provided to the DataLoader.
batch_size: The number of items in a batch.
drop_last: Whether to drop the last batch of data.
shuffle: Whether the data should be shuffled.
cache: The cache associated with the dataset.
"""
self._dataset_size = dataset_size
self._num_replicas = num_replicas
self._global_rank = global_rank
self._cache = cache
self._shuffle = shuffle
self._num_workers = num_workers or 1
self._shuffled_chunk_intervals = None
self._batch_size = batch_size
self._drop_last = drop_last
self._length = 0
# Before starting, ensures the chunk indices are properly defined.
self._validate()
def _validate(self) -> None:
"""Checks each worker is getting sucessive indices."""
if self._num_workers > 1 and not self._cache.filled:
batches: Dict[int, Any] = {}
for batch_index, batch_indices in enumerate(self):
self._length += 1
worker_index = batch_index % self._num_workers
if worker_index not in batches:
batches[worker_index] = []
batches[worker_index].extend(batch_indices)
elif len(batch_indices) > 0:
batches[worker_index].extend(batch_indices)
for indices in batches.values():
indices = np.asarray(indices)
diff = indices[1:] - (indices[:-1] + 1)
if diff.sum() != 0:
raise RuntimeError("This shouldn't have happened. There is a bug in the CacheSampler.")
def __iter__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
# When the cache is filled, we need to iterate through the chunks
if self._cache.filled:
if self._num_replicas == 1:
return self.__iter_from_chunks_non_distributed__()
return self.__iter_from_chunks_distributed__()
# shuffle is ignored while building the binarized version of the dataset
if self._num_replicas == 1:
return self.__iter_non_distributed__()
return self.__iter_distributed__()
def __iter_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
worker_size = self._dataset_size // self._num_workers
indices = list(range(self._dataset_size))
worker_indices = []
for worker_idx in range(self._num_workers):
is_last = worker_idx == self._num_workers - 1
start = worker_idx * worker_size
end = self._dataset_size if is_last else (worker_idx + 1) * worker_size
worker_indices.append(indices[start:end])
assert sum([len(s) for s in worker_indices]) == self._dataset_size
worker_indices_batches = [self._chunk_list(indices, self._batch_size) for indices in worker_indices]
yield from self.__iter_indices_per_workers__(worker_indices_batches)
def __iter_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
self.indices = list(range(self._dataset_size))
replica_size = self._dataset_size // self._num_replicas
worker_size = self._dataset_size // (self._num_replicas * self._num_workers)
for rank in range(self._num_replicas):
if rank != self._global_rank:
continue
is_last_replica = rank == self._num_replicas - 1
start_replica = rank * replica_size
end_replica = self._dataset_size if is_last_replica else (rank + 1) * replica_size
replica_indices = self.indices[start_replica:end_replica]
replica_size = len(replica_indices)
worker_indices = []
for worker_idx in range(self._num_workers):
is_last_worker = worker_idx == self._num_workers - 1
start_worker = worker_idx * worker_size
end_worker = replica_size if is_last_worker else (worker_idx + 1) * worker_size
worker_indices.append(replica_indices[start_worker:end_worker])
assert sum([len(s) for s in worker_indices]) == len(replica_indices)
worker_indices_batches = [self._chunk_list(indices, self._batch_size) for indices in worker_indices]
yield from self.__iter_indices_per_workers__(worker_indices_batches)
def __iter_from_chunks_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
chunk_intervals = self._cache.get_chunk_interval()
shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
yield from self.__iter_from_shuffled_chunks(shuffled_indexes.tolist(), shuffled_chunk_intervals)
def __iter_from_chunks_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
chunk_intervals = self._cache.get_chunk_interval()
shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
replica_chunks = []
replica_intervals = []
for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
if index % self._num_replicas == self._global_rank:
replica_chunks.append(chunk_index)
replica_intervals.append(chunk_interval)
yield from self.__iter_from_shuffled_chunks(replica_chunks, replica_intervals)
def __iter_from_shuffled_chunks(
self, shuffled_indexes: List[int], shuffled_chunk_intervals: List[List[int]]
) -> Iterator[List[Union[int, ChunkedIndex]]]:
chunks_per_workers: List[List[int]] = [[] for _ in range(self._num_workers)]
for i, chunk_index in enumerate(shuffled_indexes):
chunks_per_workers[i % self._num_workers].append(chunk_index)
indices_per_workers: List[List[ChunkedIndex]] = [[] for _ in range(self._num_workers)]
for i, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
worker_id = i % self._num_workers
interval_indices = np.arange(chunk_interval[0], chunk_interval[1])
shuffled_interval_indices = np.random.permutation(interval_indices).tolist()
is_empty = len(indices_per_workers[worker_id]) == 0
indices_per_workers[worker_id].extend(
[
ChunkedIndex(
index,
chunk_index,
chunk_indexes=chunks_per_workers[worker_id] if j == 0 and is_empty else None,
)
for j, index in enumerate(shuffled_interval_indices)
]
)
indices_per_workers_splitted = [self._chunk_list(indices, self._batch_size) for indices in indices_per_workers]
yield from self.__iter_indices_per_workers__(indices_per_workers_splitted)
def __len__(self) -> int:
return self._length
def __iter_indices_per_workers__(
self, indices_per_workers: List[List[List[Union[int, ChunkedIndex]]]]
) -> Iterator[List[Union[int, ChunkedIndex]]]:
batches: List[List[Union[int, ChunkedIndex]]] = []
counter = 0
while sum([len(v) for v in indices_per_workers]) != 0:
worker_indices = indices_per_workers[counter % self._num_workers]
if len(worker_indices) == 0:
batches.append([])
else:
batches.append(worker_indices.pop(0))
counter += 1
while True:
if len(batches[-1]) == 0:
batches.pop(-1)
else:
break
yield from batches
def _chunk_list(self, arr: List[Any], chunk_size: int) -> List[List[Any]]:
out = []
for i in range(0, len(arr), chunk_size):
slice_item = slice(i, i + chunk_size, 1)
out.append(arr[slice_item])
return out
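A small sketch of the batches produced during the caching phase with 2 workers and a batch size of 2: each worker receives contiguous indices so the writer can build chunks from ordered samples. The stub below only mimics the `filled` attribute the sampler checks:

from lightning.data.cache.sampler import CacheBatchSampler

class _FakeCache:
    filled = False  # caching phase

sampler = CacheBatchSampler(
    dataset_size=8,
    num_replicas=1,
    global_rank=0,
    num_workers=2,
    batch_size=2,
    drop_last=False,
    shuffle=False,
    cache=_FakeCache(),
)
# Worker 0 receives [0, 1] then [2, 3]; worker 1 receives [4, 5] then [6, 7].
print(list(sampler))  # [[0, 1], [4, 5], [2, 3], [6, 7]]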

src/lightning/data/cache/serializers.py (new vendored file, 223 lines)

@@ -0,0 +1,223 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
from abc import ABC, abstractmethod
from collections import OrderedDict
from io import BytesIO
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache
_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
if _PIL_AVAILABLE:
from PIL import Image
from PIL.JpegImagePlugin import JpegImageFile
else:
Image = None
JpegImageFile = None
if _TORCH_VISION_AVAILABLE:
from torchvision.io import decode_jpeg
class Serializer(ABC):
"""The base interface for any serializers.
A Serializer serialize and deserialize to and from bytes.
"""
@abstractmethod
def serialize(self, data: Any) -> Tuple[bytes, Optional[str]]:
pass
@abstractmethod
def deserialize(self, data: bytes) -> Any:
pass
@abstractmethod
def can_serialize(self, data: Any) -> bool:
pass
class PILSerializer(Serializer):
"""The PILSerializer serialize and deserialize PIL Image to and from bytes."""
def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]:
mode = item.mode.encode("utf-8")
width, height = item.size
raw = item.tobytes()
ints = np.array([width, height, len(mode)], np.uint32)
return ints.tobytes() + mode + raw, None
def deserialize(self, data: bytes) -> Any:
idx = 3 * 4
width, height, mode_size = np.frombuffer(data[:idx], np.uint32)
idx2 = idx + mode_size
mode = data[idx:idx2].decode("utf-8")
size = width, height
raw = data[idx2:]
return Image.frombytes(mode, size, raw) # pyright: ignore
def can_serialize(self, item: Any) -> bool:
return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)
class IntSerializer(Serializer):
"""The IntSerializer serialize and deserialize integer to and from bytes."""
def serialize(self, item: int) -> Tuple[bytes, Optional[str]]:
return str(item).encode("utf-8"), None
def deserialize(self, data: bytes) -> int:
return int(data.decode("utf-8"))
def can_serialize(self, item: Any) -> bool:
return isinstance(item, int)
class JPEGSerializer(Serializer):
"""The JPEGSerializer serialize and deserialize JPEG image to and from bytes."""
def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]:
if isinstance(item, JpegImageFile):
if not hasattr(item, "filename"):
raise ValueError(
"The JPEG Image's filename isn't defined. HINT: Open the image in your Dataset __getitem__ method."
)
with open(item.filename, "rb") as f:
return f.read(), None
raise TypeError(f"The provided itemect should be of type {JpegImageFile}. Found {item}.")
def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]:
if _TORCH_VISION_AVAILABLE:
array = torch.frombuffer(data, dtype=torch.uint8)
return decode_jpeg(array)
inp = BytesIO(data)
return Image.open(inp)
def can_serialize(self, item: Any) -> bool:
return isinstance(item, JpegImageFile)
class BytesSerializer(Serializer):
"""The BytesSerializer serialize and deserialize integer to and from bytes."""
def serialize(self, item: bytes) -> Tuple[bytes, Optional[str]]:
return item, None
def deserialize(self, item: bytes) -> bytes:
return item
def can_serialize(self, item: bytes) -> bool:
return isinstance(item, bytes)
_TORCH_DTYPES_MAPPING = {
0: torch.float32,
1: torch.float,
2: torch.float64,
3: torch.double,
4: torch.complex64,
5: torch.cfloat,
6: torch.complex128,
7: torch.cdouble,
8: torch.float16,
9: torch.half,
10: torch.bfloat16, # Not supported https://github.com/pytorch/pytorch/issues/110285
11: torch.uint8,
12: torch.int8,
13: torch.int16,
14: torch.short,
15: torch.int32,
16: torch.int,
17: torch.int64,
18: torch.long,
19: torch.bool,
}
class TensorSerializer(Serializer):
"""The TensorSerializer serialize and deserialize tensor to and from bytes."""
def __init__(self) -> None:
super().__init__()
self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}
def serialize(self, item: torch.Tensor) -> Tuple[bytes, Optional[str]]:
dtype_indice = self._dtype_to_indice[item.dtype]
data = [np.uint32(dtype_indice).tobytes()]
data.append(np.uint32(len(item.shape)).tobytes())
for dim in item.shape:
data.append(np.uint32(dim).tobytes())
data.append(item.numpy().tobytes())
return b"".join(data), None
def deserialize(self, data: bytes) -> torch.Tensor:
dtype_indice = np.frombuffer(data[0:4], np.uint32).item()
dtype = _TORCH_DTYPES_MAPPING[dtype_indice]
shape_size = np.frombuffer(data[4:8], np.uint32).item()
shape = []
for shape_idx in range(shape_size):
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())
tensor = torch.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
return torch.reshape(tensor, torch.Size(shape))
def can_serialize(self, item: torch.Tensor) -> bool:
return isinstance(item, torch.Tensor) and type(item) == torch.Tensor
class PickleSerializer(Serializer):
"""The PickleSerializer serialize and deserialize python objects to and from bytes."""
def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
return pickle.dumps(item), None
def deserialize(self, data: bytes) -> Any:
return pickle.loads(data)
def can_serialize(self, _: Any) -> bool:
return True
class FileSerializer(Serializer):
def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
_, file_extension = os.path.splitext(filepath)
with open(filepath, "rb") as f:
return f.read(), file_extension.replace(".", "").lower()
def deserialize(self, data: bytes) -> Any:
pass
def can_serialize(self, data: Any) -> bool:
return isinstance(data, str) and os.path.exists(data)
_SERIALIZERS = OrderedDict(
**{
"file": FileSerializer(),
"pil": PILSerializer(),
"int": IntSerializer(),
"jpeg": JPEGSerializer(),
"bytes": BytesSerializer(),
"tensor": TensorSerializer(),
"pickle": PickleSerializer(),
}
)
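A quick sketch of a round trip through one of the serializers registered above:

import torch
from lightning.data.cache.serializers import _SERIALIZERS

serializer = _SERIALIZERS["tensor"]
data, _ = serializer.serialize(torch.arange(6, dtype=torch.float32).reshape(2, 3))
restored = serializer.deserialize(data)  # tensor([[0., 1., 2.], [3., 4., 5.]])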

src/lightning/data/cache/writer.py (new vendored file, 305 lines)

@@ -0,0 +1,305 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from time import sleep
from typing import Any, Dict, List, Optional
import numpy as np
from lightning.data.cache.compression import _COMPRESSORS, Compressor
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.serializers import _SERIALIZERS, Serializer
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
if _TORCH_2_1_0_AVAILABLE:
from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps
class BinaryWriter:
def __init__(
self,
cache_dir: str,
chunk_size: Optional[int] = None,
chunk_bytes: Optional[int] = None,
compression: Optional[str] = None,
):
"""The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training.
Arguments:
cache_dir: The path to where the chunks will be saved.
chunk_bytes: The maximum number of bytes within a chunk.
chunk_size: The maximum number of items within a chunk.
compression: The compression algorithm to use.
"""
self._cache_dir = cache_dir
if not os.path.exists(self._cache_dir):
raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.")
if (chunk_size is None and chunk_bytes is None) or (chunk_size and chunk_bytes):
raise ValueError("Either one of the `chunk_size` or the `chunk_bytes` need to be provided.")
self._serializers: Dict[str, Serializer] = _SERIALIZERS
self._chunk_size = chunk_size
self._chunk_bytes = chunk_bytes
self._compression = compression
self._data_format: Optional[List[str]] = None
self._data_spec: Optional[PyTree] = None
if self._compression:
if len(_COMPRESSORS) == 0:
raise ValueError("No compresion algorithms are installed.")
if self._compression not in _COMPRESSORS:
raise ValueError(
f"The provided compression {self._compression} isn't available in {sorted(_COMPRESSORS)}"
)
self._compressor: Compressor = _COMPRESSORS[self._compression]
self._current_chunk_bytes = 0
self._chunk_index = 0
self._serialized_items: List[bytes] = []
self._chunks_info: List[Dict[str, Any]] = []
self._indexes: List[int] = []
self._worker_env: Optional[_WorkerEnv] = None
self._rank: Optional[int] = None
self._is_done = False
self._distributed_env = _DistributedEnv.detect()
@property
def filled(self) -> bool:
"""Returns whether the caching phase is done."""
if self._is_done:
return True
files = os.listdir(self._cache_dir)
index_files = [f for f in files if f.endswith(_INDEX_FILENAME)]
worker_end = _WorkerEnv.detect()
self._is_done = len(index_files) == self._distributed_env.world_size * worker_end.world_size
return self._is_done
@property
def rank(self) -> int:
"""Returns the rank of the writer."""
if self._rank is None:
self._worker_env = _WorkerEnv.detect()
self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank
return self._rank
def get_config(self) -> Dict[str, Any]:
"""Returns the config of the writer."""
out = {
"compression": self._compression,
"chunk_size": self._chunk_size,
"chunk_bytes": self._chunk_bytes,
"data_format": self._data_format,
"data_spec": treespec_dumps(self._data_spec) if self._data_spec else None,
}
return out
def serialize(self, items: Any) -> bytes:
"""Serialize a dictionary into its binary format."""
# Flatten the items provided by the users
flattened, data_spec = tree_flatten(items)
# Collect the sizes and associated bytes for each item
sizes: List[int] = []
data: List[bytes] = []
data_format: List[str] = []
for item in flattened:
data_format.append(self._serialize(item, sizes, data))
if self._data_format is None:
self._data_format = data_format
elif self._data_format != data_format:
raise Exception(
f"The data format changed between items. Found {data_format} instead of {self._data_format}."
)
if self._data_spec is None:
self._data_spec = data_spec
elif self._data_spec != data_spec:
raise Exception(f"The data format changed between items. Found {data_spec} instead of {self._data_spec}.")
# Concatenate into a single byte array
head = np.array(sizes, np.uint32).tobytes()
body = b"".join(data)
return head + body
def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> str:
"""Serialize a given item and append its size and bytes to the sizes and data array."""
for serializer_name, serializer in self._serializers.items():
if serializer.can_serialize(item):
serialized_item, name = serializer.serialize(item)
data.append(serialized_item)
sizes.append(len(serialized_item))
return name or serializer_name
raise ValueError(f"The provided item isn't serializable. Found {item}")
def _create_chunk(self, filename: str) -> bytes:
"""Create a binary chunk from all the binarized items."""
num_items = np.uint32(len(self._serialized_items))
sizes = list(map(len, self._serialized_items))
offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
offsets += len(num_items.tobytes()) + len(offsets.tobytes())
sample_data = b"".join(self._serialized_items)
data = num_items.tobytes() + offsets.tobytes() + sample_data
offsets = offsets.tolist()
mapping = {}
for i in range(len(self._indexes)):
mapping[self._indexes[i]] = [offsets[i], offsets[i + 1]]
assert len(mapping) == len(self._indexes)
assert (self._indexes[-1] - self._indexes[0] + 1) == len(self._serialized_items)
chunk_info = {
"chunk_bytes": self._current_chunk_bytes,
"chunk_size": len(self._serialized_items),
"filename": filename,
"interval": [self._indexes[0], self._indexes[-1] + 1],
}
self._chunks_info.append(chunk_info)
return data
def write_chunk(self) -> None:
"""Write a chunk to the filesystem."""
if self._compression:
filename = f"chunk-{self.rank}-{self._chunk_index}.{self._compression}.bin"
else:
filename = f"chunk-{self.rank}-{self._chunk_index}.bin"
self.write_chunk_to_file(self._create_chunk(filename), filename)
def reset(self) -> None:
"""Reset the writer to handle the next chunk."""
self._serialized_items = []
self._indexes = []
self._current_chunk_bytes = 0
def __setitem__(self, index: int, items: Any) -> None:
"""Store an item to a chunk.
The index needs to be provided in order.
This is handled by the samplers automatically. This ensures we can map an index to a shard from an interval.
"""
# Serialize the items
serialized_items = self.serialize(items)
serialized_items_size = len(serialized_items)
# Check whether it is time to write a chunk
should_write = (
self._chunk_bytes and self._chunk_bytes < self._current_chunk_bytes + serialized_items_size
) or (self._chunk_size and len(self._indexes) >= self._chunk_size)
if should_write:
if self._current_chunk_bytes == 0:
raise Exception(
f"The provided chunk_size {self._chunk_bytes} is too small."
f" You should use a multiple of {serialized_items_size} bytes."
)
self.write_chunk()
self.reset()
self._chunk_index += 1
# Store the serialized items into the chunk.
self._serialized_items.append(serialized_items)
self._current_chunk_bytes += serialized_items_size
# Validate the indexes are provided in incremental order
# This is required to ensure we can efficiently find a chunk index from an index using the chunk intervals
if self._indexes:
assert self._indexes[-1] == index - 1, (self._indexes, index - 1)
# Store the index
self._indexes.append(index)
def write_chunk_to_file(
self,
raw_data: bytes,
filename: str,
) -> None:
"""Write chunk bytes to a file."""
# Whether to compress the raw bytes
if self._compression:
raw_data = self._compressor.compress(raw_data)
# Write the binary chunk file
with open(os.path.join(self._cache_dir, filename), "wb") as out:
out.write(raw_data)
def write_chunks_index(self) -> None:
"""Write the chunks index to a JSON file."""
filepath = os.path.join(self._cache_dir, f"{self.rank}.{_INDEX_FILENAME}")
config = self.get_config()
with open(filepath, "w") as out:
json.dump({"chunks": self._chunks_info, "config": config}, out, sort_keys=True)
def done(self) -> None:
"""Called when StopIteration is triggered."""
if self.filled:
return
if self._serialized_items:
self.write_chunk()
self.write_chunks_index()
self.reset()
self._is_done = True
def merge(self, num_workers: int = 1) -> None:
"""Once all the workers have written their own index, the merge function is responsible to read and merge them
into a single index."""
num_workers = num_workers or 1
# Only for non rank 0
if self.rank != 0:
while not os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME)):
sleep(0.001)
return
# Wait for all indexes to be available
is_done = False
while not is_done:
files = os.listdir(self._cache_dir)
if _INDEX_FILENAME in files:
return
index_files = [f for f in files if f.endswith(_INDEX_FILENAME)]
is_done = len(index_files) == self._distributed_env.world_size * num_workers
sleep(0.001)
# Read the index and append the chunks together
chunks_info = []
config = None
for index_filename in sorted(index_files):
chunk_path = os.path.join(self._cache_dir, index_filename)
with open(chunk_path) as f:
data = json.load(f)
if config is None:
config = data["config"]
elif config != data["config"]:
raise Exception("The config isn't consistent between chunks. This shouldn't have happened.")
chunks_info.extend(data["chunks"])
os.remove(chunk_path)
# Write down the collected index
with open(os.path.join(self._cache_dir, _INDEX_FILENAME), "w") as f:
json.dump({"chunks": chunks_info, "config": config}, f, sort_keys=True)


@@ -1,7 +1,7 @@
from typing import Optional
from typing import Callable, Optional
import torch
from torch.utils.data import get_worker_info
from torch.utils.data import get_worker_info as torch_get_worker_info
class _DistributedEnv:
@@ -60,7 +60,7 @@ class _WorkerEnv:
self.rank = rank
@classmethod
def detect(cls) -> "_WorkerEnv":
def detect(cls, get_worker_info_fn: Optional[Callable] = None) -> "_WorkerEnv":
"""Automatically detects the number of workers and the current rank.
Note:
@@ -68,6 +68,7 @@ class _WorkerEnv:
In such a case it will default to 1 worker
"""
get_worker_info = get_worker_info_fn or torch_get_worker_info
worker_info = get_worker_info()
num_workers = worker_info.num_workers if worker_info is not None else 1
current_worker_rank = worker_info.id if worker_info is not None else 0
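A sketch of the injection point added above, which makes worker detection easy to stub outside of a real DataLoader worker (the fake worker info is illustrative):

from types import SimpleNamespace
from lightning.data.datasets.env import _WorkerEnv

fake_info = SimpleNamespace(num_workers=4, id=2)
env = _WorkerEnv.detect(get_worker_info_fn=lambda: fake_info)
assert env.world_size == 4 and env.rank == 2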

tests/tests_data/cache/__init__.py (new vendored file, empty)

tests/tests_data/cache/test_cache.py (new vendored file, 205 lines)

@@ -0,0 +1,205 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from functools import partial
import numpy as np
import pytest
import torch
from lightning import seed_everything
from lightning.data.cache import Cache
from lightning.data.cache.dataloader import LightningDataLoader
from lightning.data.datasets.env import _DistributedEnv
from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning_utilities.core.imports import RequirementCache
from torch.utils.data import Dataset
_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
class ImageDataset(Dataset):
def __init__(self, tmpdir, cache, size, num_classes):
from PIL import Image
self.data = []
self.cache = cache
seed_everything(42)
for i in range(size):
path = os.path.join(tmpdir, f"img{i}.jpeg")
np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8)
img = Image.fromarray(np_data).convert("L")
img.save(path, format="jpeg", quality=100)
self.data.append({"image": path, "class": np.random.randint(num_classes)})
def __len__(self):
return len(self.data)
def __getitem__(self, index):
if self.cache.filled:
return self.cache[index]
self.cache[index] = {**self.data[index], "index": index}
return None
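# Note: on the first pass `cache.filled` is False, so each item is written into the cache and
# `__getitem__` returns None; once the cache is complete, later accesses return the cached copy.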
def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):
from PIL import Image
from torchvision.transforms import PILToTensor
dataset_size = 85
cache_dir = os.path.join(tmpdir, "cache")
distributed_env = _DistributedEnv.detect()
cache = Cache(cache_dir, chunk_size=10)
dataset = ImageDataset(tmpdir, cache, dataset_size, 10)
dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4)
for _ in dataloader:
pass
# Not strictly required, but added to avoid a race condition
if distributed_env.world_size > 1:
fabric.barrier()
assert cache.filled
for i in range(len(dataset)):
cached_data = dataset[i]
original_data = dataset.data[i]
assert cached_data["class"] == original_data["class"]
original_array = PILToTensor()(Image.open(original_data["image"]))
assert torch.equal(original_array, cached_data["image"])
if distributed_env.world_size == 1:
indexes = []
dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4)
for batch in dataloader:
if batch:
indexes.extend(batch["index"].numpy().tolist())
assert len(indexes) == dataset_size
seed_everything(42)
dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4, shuffle=True)
dataloader_iter = iter(dataloader)
indexes = []
for batch in dataloader_iter:
indexes.extend(batch["index"].numpy().tolist())
if distributed_env.world_size == 1:
assert len(indexes) == dataset_size
indexes2 = []
for batch in dataloader_iter:
indexes2.extend(batch["index"].numpy().tolist())
assert indexes2 != indexes
@pytest.mark.skipif(
condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']"
)
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_cache_for_image_dataset(num_workers, tmpdir):
cache_dir = os.path.join(tmpdir, "cache")
os.makedirs(cache_dir)
_cache_for_image_dataset(num_workers, tmpdir)
def _fabric_cache_for_image_dataset(fabric, num_workers, tmpdir):
_cache_for_image_dataset(num_workers, tmpdir, fabric=fabric)
@pytest.mark.skipif(
condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE or sys.platform == "win32",
reason="Requires: ['pil', 'torchvision']",
)
@pytest.mark.parametrize("num_workers", [2])
def test_cache_for_image_dataset_distributed(num_workers, tmpdir):
cache_dir = os.path.join(tmpdir, "cache")
os.makedirs(cache_dir)
fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn")
fabric.launch(partial(_fabric_cache_for_image_dataset, num_workers=num_workers, tmpdir=tmpdir))
def test_cache_with_simple_format(tmpdir):
cache_dir = os.path.join(tmpdir, "cache1")
os.makedirs(cache_dir)
cache = Cache(cache_dir, chunk_bytes=90)
for i in range(100):
cache[i] = i
cache.done()
cache.merge()
for i in range(100):
assert i == cache[i]
cache_dir = os.path.join(tmpdir, "cache2")
os.makedirs(cache_dir)
cache = Cache(cache_dir, chunk_bytes=90)
for i in range(100):
cache[i] = [i, {0: [i + 1]}]
cache.done()
cache.merge()
for i in range(100):
assert [i, {0: [i + 1]}] == cache[i]
def test_cache_with_auto_wrapping(tmpdir):
os.makedirs(os.path.join(tmpdir, "cache_1"), exist_ok=True)
dataset = RandomDataset(64, 64)
dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_1"), chunk_bytes=2 << 12)
for batch in dataloader:
assert isinstance(batch, torch.Tensor)
assert sorted(os.listdir(os.path.join(tmpdir, "cache_1"))) == [
"chunk-0-0.bin",
"chunk-0-1.bin",
"chunk-0-2.bin",
"index.json",
]
# A dataset whose items are generated at runtime, so it can't be cached deterministically
class RandomDatasetAtRuntime(Dataset):
def __init__(self, size: int, length: int):
self.len = length
self.size = size
def __getitem__(self, index: int) -> torch.Tensor:
return torch.randn(1, self.size)
def __len__(self) -> int:
return self.len
os.makedirs(os.path.join(tmpdir, "cache_2"), exist_ok=True)
dataset = RandomDatasetAtRuntime(64, 64)
dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_2"), chunk_bytes=2 << 12)
with pytest.raises(ValueError, match="Your dataset items aren't deterministic"):
for batch in dataloader:
pass

112 tests/tests_data/cache/test_sampler.py vendored Normal file
View File

@ -0,0 +1,112 @@
from unittest import mock
import pytest
from lightning import seed_everything
from lightning.data.cache.sampler import CacheBatchSampler
@pytest.mark.parametrize(
"params",
[
(
21,
1,
[[0, 1, 2], [7, 8, 9], [14, 15, 16], [3, 4, 5], [10, 11, 12], [17, 18, 19], [6], [13], [20]],
[[7, 0, 0], [1, 1, 1], [5, 5, 5], [0, 4, 4], [8, 3, 3], [2, 2, 2], [4], [3], [6]],
),
(
11,
1,
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]],
[[1, 1, 1], [3, 3], [0, 0, 0], [2, 2, 2]],
),
(8, 1, [[0, 1], [2, 3], [4, 5, 6], [], [], [7]], [[1, 1, 2], [3], [0, 0], [2, 2]]),
(4, 1, [[0], [1], [2, 3]], [[0], [1], [2, 2]]),
(
9,
1,
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[0, 0, 0], [1, 1, 1], [2, 2, 2]],
),
(
19,
1,
[[0, 1, 2], [6, 7, 8], [12, 13, 14], [3, 4, 5], [9, 10, 11], [15, 16, 17], [], [], [18]],
[[0, 0, 0], [1, 1, 1], [5, 5, 5], [2, 2, 2], [4, 4, 4], [3, 3, 3], [6]],
),
(19, 2, [[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[0, 0, 0], [5, 5, 5], [4, 4, 4], [6]]),
],
)
def test_cache_batch_sampler(params):
seed_everything(42)
cache = mock.MagicMock()
cache.filled = False
if params[1] > 1:
batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache)
batches = []
for batch in batch_sampler:
batches.append(batch)
assert batches == params[2], batches
batch_sampler = CacheBatchSampler(params[0], 1, 0, 3, 3, False, True, cache)
batches = []
for batch in batch_sampler:
batches.append(batch)
chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)]
else:
batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache)
batches = []
for batch in batch_sampler:
batches.append(batch)
assert batches == params[2], batches
chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)]
cache.filled = True
cache.get_chunk_interval.return_value = chunks_interval
seed_everything(42)
batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache)
batches_1 = []
for batch in batch_sampler:
batches_1.append(batch)
def validate_batch(data, check_values):
if params[1] == 1:
assert all(b[0].chunk_indexes is not None for b in data[:3])
assert all(b[1].chunk_indexes is None if len(b) > 1 else True for b in data[:3])
assert all(b[0].chunk_indexes is None if len(b) else True for b in data[3:])
if check_values:
assert [[x.chunk_index for x in d] for d in data] == params[3]
else:
assert all(b[0].chunk_indexes is not None for b in data[:3])
assert all(b[1].chunk_indexes is None if len(b) > 1 else True for b in data[:3])
assert all(b[0].chunk_indexes is None if len(b) else True for b in data[3:])
if check_values:
assert [[x.chunk_index for x in d] for d in data] == params[3]
validate_batch(batches_1, True)
batches_2 = []
for batch in batch_sampler:
batches_2.append(batch)
validate_batch(batches_2, False)
if params[1] == 1:
assert batches_1 != batches_2
def test_batch_sampler_imagenet():
"""Validate the Imagenet dataset is valid."""
dataset_size = 1281167
world_size = 1
rank = 0
num_workers = 32
batch_size = 8
cache = mock.MagicMock()
cache.filled = False
CacheBatchSampler(dataset_size, world_size, rank, num_workers, batch_size, False, True, cache)

View File

@ -0,0 +1,115 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from time import time
import numpy as np
import pytest
import torch
from lightning import seed_everything
from lightning.data.cache.serializers import (
_SERIALIZERS,
_TORCH_DTYPES_MAPPING,
IntSerializer,
PickleSerializer,
PILSerializer,
TensorSerializer,
)
from lightning_utilities.core.imports import RequirementCache
_PIL_AVAILABLE = RequirementCache("PIL")
def test_serializers():
assert list(_SERIALIZERS.keys()) == ["file", "pil", "int", "jpeg", "bytes", "tensor", "pickle"]
def test_int_serializer():
serializer = IntSerializer()
for i in range(100):
data, _ = serializer.serialize(i)
assert isinstance(data, bytes)
assert i == serializer.deserialize(data)
@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
@pytest.mark.parametrize("mode", ["I", "L", "RGB"])
def test_pil_serializer(mode):
serializer = PILSerializer()
from PIL import Image
np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32)
img = Image.fromarray(np_data).convert(mode)
data, _ = serializer.serialize(img)
assert isinstance(data, bytes)
deserialized_img = serializer.deserialize(data)
deserialized_img = deserialized_img.convert("I")
np_dec_data = np.asarray(deserialized_img, dtype=np.uint32)
assert isinstance(deserialized_img, Image.Image)
# Validate data content
assert np.array_equal(np_data, np_dec_data)
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_tensor_serializer():
seed_everything(42)
serializer_tensor = TensorSerializer()
serializer_pickle = PickleSerializer()
ratio_times = []
ratio_bytes = []
shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]
for dtype in _TORCH_DTYPES_MAPPING.values():
for shape in shapes:
# bfloat16 tensors can't be serialized (see test_assert_bfloat16_tensor_serializer below), so skip them
if dtype in [torch.bfloat16]:
continue
tensor = torch.ones(shape, dtype=dtype)
t0 = time()
data, _ = serializer_tensor.serialize(tensor)
deserialized_tensor = serializer_tensor.deserialize(data)
tensor_time = time() - t0
tensor_bytes = len(data)
assert deserialized_tensor.dtype == dtype
assert torch.equal(tensor, deserialized_tensor)
t1 = time()
data, _ = serializer_pickle.serialize(tensor)
deserialized_tensor = serializer_pickle.deserialize(data)
pickle_time = time() - t1
pickle_bytes = len(data)
assert deserialized_tensor.dtype == dtype
assert torch.equal(tensor, deserialized_tensor)
ratio_times.append(pickle_time / tensor_time)
ratio_bytes.append(pickle_bytes / tensor_bytes)
assert np.mean(ratio_times) > 3.5
assert np.mean(ratio_bytes) > 2
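# Averaged over these dtypes and shapes, the dedicated tensor serializer is expected to be more than
# 3.5x faster than pickling and to produce buffers less than half the size.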
def test_assert_bfloat16_tensor_serializer():
serializer = TensorSerializer()
tensor = torch.ones((10,), dtype=torch.bfloat16)
with pytest.raises(TypeError, match="Got unsupported ScalarType BFloat16"):
serializer.serialize(tensor)

168 tests/tests_data/cache/test_writer.py vendored Normal file
View File

@ -0,0 +1,168 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import numpy as np
import pytest
from lightning.data.cache.reader import BinaryReader
from lightning.data.cache.sampler import ChunkedIndex
from lightning.data.cache.writer import BinaryWriter
from lightning_utilities.core.imports import RequirementCache
_PIL_AVAILABLE = RequirementCache("PIL")
def test_binary_writer_with_ints_and_chunk_bytes(tmpdir):
with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."):
BinaryWriter("dontexists", {})
with pytest.raises(ValueError, match="No compresion algorithms are installed."):
BinaryWriter(tmpdir, {"i": "int"}, compression="something_else")
binary_writer = BinaryWriter(tmpdir, chunk_bytes=90)
for i in range(100):
binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2}
assert len(os.listdir(tmpdir)) == 19
binary_writer.done()
binary_writer.merge()
assert len(os.listdir(tmpdir)) == 21
with open(os.path.join(tmpdir, "index.json")) as f:
data = json.load(f)
assert data["chunks"][0]["chunk_size"] == 6
assert data["chunks"][1]["chunk_size"] == 5
assert data["chunks"][-1]["chunk_size"] == 4
chunk_sizes = np.cumsum([chunk["chunk_size"] for chunk in data["chunks"]])
reader = BinaryReader(tmpdir)
for i in range(100):
for chunk_index, chunk_start in enumerate(chunk_sizes):
if i >= chunk_start:
continue
break
data = reader.read(ChunkedIndex(i, chunk_index=chunk_index))
assert data == {"i": i, "i+1": i + 1, "i+2": i + 2}
def test_binary_writer_with_ints_and_chunk_size(tmpdir):
with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."):
BinaryWriter("dontexists", {})
with pytest.raises(ValueError, match="No compresion algorithms are installed."):
BinaryWriter(tmpdir, {"i": "int"}, compression="something_else")
binary_writer = BinaryWriter(tmpdir, chunk_size=25)
for i in range(100):
binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2}
assert len(os.listdir(tmpdir)) == 3
binary_writer.done()
binary_writer.merge()
assert len(os.listdir(tmpdir)) == 5
with open(os.path.join(tmpdir, "index.json")) as f:
data = json.load(f)
assert data["chunks"][0]["chunk_size"] == 25
assert data["chunks"][1]["chunk_size"] == 25
assert data["chunks"][-1]["chunk_size"] == 25
reader = BinaryReader(tmpdir)
for i in range(100):
data = reader.read(ChunkedIndex(i, chunk_index=i // 25))
assert data == {"i": i, "i+1": i + 1, "i+2": i + 2}
@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
def test_binary_writer_with_jpeg_and_int(tmpdir):
"""Validate the writer and reader can serialize / deserialize a pair of image and label."""
from PIL import Image
cache_dir = os.path.join(tmpdir, "chunks")
os.makedirs(cache_dir, exist_ok=True)
binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12)
imgs = []
for i in range(100):
path = os.path.join(tmpdir, f"img{i}.jpeg")
np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8)
img = Image.fromarray(np_data).convert("L")
img.save(path, format="jpeg", quality=100)
img = Image.open(path)
imgs.append(img)
binary_writer[i] = {"x": img, "y": i}
assert len(os.listdir(cache_dir)) == 24
binary_writer.done()
binary_writer.merge()
assert len(os.listdir(cache_dir)) == 26
with open(os.path.join(cache_dir, "index.json")) as f:
data = json.load(f)
assert data["chunks"][0]["chunk_size"] == 4
assert data["chunks"][1]["chunk_size"] == 4
assert data["chunks"][-1]["chunk_size"] == 4
reader = BinaryReader(cache_dir)
for i in range(100):
data = reader.read(ChunkedIndex(i, chunk_index=i // 4))
np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i])
assert data["y"] == i
@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
def test_binary_writer_with_jpeg_filepath_and_int(tmpdir):
"""Validate the writer and reader can serialize / deserialize a pair of image and label."""
from PIL import Image
cache_dir = os.path.join(tmpdir, "chunks")
os.makedirs(cache_dir, exist_ok=True)
binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12)
imgs = []
for i in range(100):
path = os.path.join(tmpdir, f"img{i}.jpeg")
np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8)
img = Image.fromarray(np_data).convert("L")
img.save(path, format="jpeg", quality=100)
img = Image.open(path)
imgs.append(img)
binary_writer[i] = {"x": path, "y": i}
assert len(os.listdir(cache_dir)) == 24
binary_writer.done()
binary_writer.merge()
assert len(os.listdir(cache_dir)) == 26
with open(os.path.join(cache_dir, "index.json")) as f:
data = json.load(f)
assert data["chunks"][0]["chunk_size"] == 4
assert data["chunks"][1]["chunk_size"] == 4
assert data["chunks"][-1]["chunk_size"] == 4
reader = BinaryReader(cache_dir)
for i in range(100):
data = reader.read(ChunkedIndex(i, chunk_index=i // 4))
np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i])
assert data["y"] == i

View File

@ -17,8 +17,7 @@ def get_test_index_data(index_path):
return list(dict.fromkeys([item.split("/")[-1] for item in data if "jpeg" in item]))
@pytest.fixture(scope="session")
def image_set(tmp_path_factory):
def image_set(tmpdir):
from PIL import Image
file_nums = [
@ -45,11 +44,11 @@ def image_set(tmp_path_factory):
img = img.astype(np.uint8)
im = Image.fromarray(img)
for i in file_nums:
fn = tmp_path_factory.mktemp("test_data") / f"img-{i}.jpeg"
im.save(fn)
folder_path = os.path.join(tmpdir, "test_data")
os.makedirs(folder_path, exist_ok=True)
return tmp_path_factory.getbasetemp()._str
for i in file_nums:
im.save(os.path.join(folder_path, f"img-{i}.jpeg"))
@pytest.mark.xfail(strict=False, reason="Need a valid AWS key and AWS secret key in CI for this to work")
@ -70,7 +69,6 @@ def test_get_index_generate_for_s3_bucket(monkeypatch):
test_bucket = "s3://nohaspublictestbucket"
index_path = os.path.join(os.getcwd(), "index_1.txt")
print(index_path)
got_index = get_index(s3_connection_path=test_bucket, index_file_path=index_path)
assert got_index
@ -80,13 +78,16 @@ def test_get_index_generate_for_s3_bucket(monkeypatch):
assert len(test_index_data) == len(generated_index)
assert test_index_data == generated_index
os.remove(index_path)
@pytest.mark.skipif(not package_available("lightning"), reason="Supported only with mono-package")
@mock.patch("lightning.data.datasets.index.LightningClient", MagicMock())
def test_get_index_generate_for_local_folder(image_set, monkeypatch):
def test_get_index_generate_for_local_folder(monkeypatch, tmpdir):
"""Can generate an index for an s3 bucket."""
image_set(tmpdir)
client = MagicMock()
client.projects_service_list_project_cluster_bindings.return_value = None
client.data_connection_service_list_data_connections.return_value = None
@ -100,7 +101,7 @@ def test_get_index_generate_for_local_folder(image_set, monkeypatch):
# test_local_bucket = "data/test_dataset"
index_path = os.path.join(THIS_DIR, "index_2.txt")
got_index = get_index(s3_connection_path=image_set, index_file_path=index_path)
got_index = get_index(s3_connection_path=str(tmpdir), index_file_path=index_path)
assert got_index