Introduce Cache 1/n (#18642)
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 6537a05977
commit 1d5851ffe2
@@ -173,9 +173,9 @@ subprojects:
-      - "!*.md"
+      - "!**/*.md"
       checks:
-        - "data-cpu (macOS-11, lightning, 3.10, 2.0)"
-        - "data-cpu (ubuntu-20.04, lightning, 3.10, 2.0)"
-        - "data-cpu (windows-2022, lightning, 3.10, 2.0)"
+        - "data-cpu (macOS-11, lightning, 3.10, 2.1)"
+        - "data-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
+        - "data-cpu (windows-2022, lightning, 3.10, 2.1)"

   # SECTION: lightning_fabric
@@ -34,9 +34,9 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
-          - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
-          - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
+          - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
+          - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
+          - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
           # "oldest" versions tests, only on minimum Python
           # - {os: "macOS-11", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"}
           # - {os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"}
@@ -3,6 +3,6 @@

 lightning-utilities >=0.8.0, <0.10.0
 # to be able to include also 0.6 and preserve `>` needed for CI min version bypass
-torchdata >0.5.9, <0.7.0
+torchdata >0.5.9, <=0.7.0
 # to be able to include also PL 2.0 and preserve `>` needed for CI min version bypass
-torch >0.14.0, <2.1.0
+torch >0.14.0, <=2.1.0
@@ -0,0 +1,17 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lightning.data.cache.cache import Cache
from lightning.data.cache.dataloader import LightningDataLoader

__all__ = ["Cache", "LightningDataLoader"]
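With this `__init__.py`, both public entry points are importable from the package root. A minimal sketch of the intended import (the cache directory below is an illustrative placeholder, not part of the commit):

import os

from lightning.data.cache import Cache, LightningDataLoader

os.makedirs("/tmp/my_cache", exist_ok=True)  # the writer requires an existing directory
cache = Cache("/tmp/my_cache", chunk_size=64)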
@@ -0,0 +1,90 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.reader import BinaryReader
from lightning.data.cache.sampler import ChunkedIndex
from lightning.data.cache.writer import BinaryWriter
from lightning.data.datasets.env import _DistributedEnv

logger = logging.getLogger(__name__)


class Cache:
    def __init__(
        self,
        cache_dir: str,
        remote_dir: Optional[str] = None,
        compression: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_bytes: Optional[int] = None,
    ):
        """The Cache enables optimising the dataset format for cloud training. This is done by grouping several
        elements together in order to accelerate fetching.

        Arguments:
            cache_dir: The path to where the chunks will be stored.
            remote_dir: The path to a remote folder where the data is located.
                The scheme needs to be added to the path.
            compression: The name of the algorithm used to reduce the size of the chunks.
            chunk_bytes: The maximum number of bytes within a chunk.
            chunk_size: The maximum number of items within a chunk.

        """
        super().__init__()
        if not _TORCH_2_1_0_AVAILABLE:
            raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")
        self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
        self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression)
        self._cache_dir = cache_dir
        self._is_done = False
        self._distributed_env = _DistributedEnv.detect()

    @property
    def filled(self) -> bool:
        """Returns whether the caching phase is done."""
        if self._is_done:
            return True
        self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
        return self._is_done

    def __setitem__(self, index: int, data: Any) -> None:
        """Store an item in the writer."""
        self._writer[index] = data

    def __getitem__(self, index: Union[int, ChunkedIndex]) -> Dict[str, Any]:
        """Read an item from the reader."""
        if isinstance(index, int):
            index = ChunkedIndex(index, self._get_chunk_index_from_index(index))
        return self._reader.read(index)

    def done(self) -> None:
        """Inform the writer that the chunking phase is finished."""
        self._writer.done()

    def merge(self, num_workers: int = 1) -> None:
        """Inform the writer to merge the index files created by the workers."""
        self._writer.merge(num_workers)

    def __len__(self) -> int:
        return self._reader.get_length()

    def get_chunk_interval(self) -> List[Tuple[int, int]]:
        return self._reader.get_chunk_interval()

    def _get_chunk_index_from_index(self, index: int) -> int:
        return self._reader._get_chunk_index_from_index(index)
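Taken together, the write path (`__setitem__`, `done`, `merge`) and the read path (`__getitem__`, `__len__`) suggest a two-phase flow. A minimal sketch of that flow, assuming PyTorch >= 2.1 and a writable local directory (the path and item contents are illustrative, not part of the commit):

import os

from lightning.data.cache import Cache

cache_dir = "/tmp/demo_cache"  # hypothetical path
os.makedirs(cache_dir, exist_ok=True)

cache = Cache(cache_dir, chunk_size=64)

if not cache.filled:
    # Phase 1: chunking. Items go through the BinaryWriter.
    for i in range(1000):
        cache[i] = {"index": i, "square": i * i}
    cache.done()   # flush the last chunk
    cache.merge()  # merge the per-worker indexes into index.json

# Phase 2: reading. Items come back through the BinaryReader.
print(len(cache), cache[42])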
@@ -0,0 +1,76 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Dict, TypeVar

from lightning_utilities.core.imports import RequirementCache, requires

_ZSTD_AVAILABLE = RequirementCache("zstd")

if _ZSTD_AVAILABLE:
    import zstd

TCompressor = TypeVar("TCompressor", bound="Compressor")


class Compressor(ABC):
    """Base class for compression algorithms."""

    @abstractmethod
    def compress(self, data: bytes) -> bytes:
        pass

    @abstractmethod
    def decompress(self, data: bytes) -> bytes:
        pass

    @classmethod
    @abstractmethod
    def register(cls, compressors: Dict[str, "Compressor"]) -> None:
        pass


class ZSTDCompressor(Compressor):
    """Compressor for the zstd package."""

    @requires("zstd")
    def __init__(self, level: int) -> None:
        super().__init__()
        self.level = level
        self.extension = "zstd"

    @property
    def name(self) -> str:
        return f"{self.extension}:{self.level}"

    def compress(self, data: bytes) -> bytes:
        return zstd.compress(data, self.level)

    def decompress(self, data: bytes) -> bytes:
        return zstd.decompress(data)

    @classmethod
    def register(cls, compressors: Dict[str, "Compressor"]) -> None:  # type: ignore
        if not _ZSTD_AVAILABLE:
            return

        # default
        compressors["zstd"] = ZSTDCompressor(4)

        for level in range(1, 23):
            compressors[f"zstd:{level}"] = ZSTDCompressor(level)


_COMPRESSORS: Dict[str, Compressor] = {}

ZSTDCompressor.register(_COMPRESSORS)
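The registration pattern makes the compressor set extensible: any subclass that fills `_COMPRESSORS` becomes addressable by name (e.g. "zstd:4"). As a sketch only, and not part of this commit, a hypothetical compressor backed by the stdlib zlib module could plug in the same way:

import zlib
from typing import Dict

from lightning.data.cache.compression import _COMPRESSORS, Compressor


class ZlibCompressor(Compressor):
    """Hypothetical compressor backed by the stdlib zlib module."""

    def __init__(self, level: int = 6) -> None:
        self.level = level

    def compress(self, data: bytes) -> bytes:
        return zlib.compress(data, self.level)

    def decompress(self, data: bytes) -> bytes:
        return zlib.decompress(data)

    @classmethod
    def register(cls, compressors: Dict[str, "Compressor"]) -> None:
        for level in range(1, 10):
            compressors[f"zlib:{level}"] = cls(level)


ZlibCompressor.register(_COMPRESSORS)  # Cache(..., compression="zlib:6") would then resolve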
@@ -0,0 +1,125 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import Any, Dict, List, Optional, Tuple

from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.downloader import get_downloader_cls
from lightning.data.cache.sampler import ChunkedIndex

if _TORCH_2_1_0_AVAILABLE:
    from torch.utils._pytree import treespec_loads


class ChunksConfig:
    def __init__(self, cache_dir: str, remote_dir: Optional[str]):
        """The ChunksConfig reads the index file associated with a chunked dataset and enables mapping an index to
        its chunk.

        Arguments:
            cache_dir: The path to the cache folder.
            remote_dir: The path to a remote folder where the data is located.
                The scheme needs to be added to the path.

        """
        self._cache_dir = cache_dir
        self._intervals: List[Tuple[int, int]] = []
        self._config = None
        self._chunks: List[Dict[str, Any]] = []
        self._remote_dir = remote_dir

        with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
            data = json.load(f)

            self._config = data["config"]

            self._chunks.extend(data["chunks"])

        self._config["data_spec"] = treespec_loads(self._config["data_spec"])

        for chunk in self._chunks:
            start, end = chunk["interval"]
            if (end - start) != chunk["chunk_size"]:
                raise Exception(
                    "The config intervals don't match the number of samples. This shouldn't have happened."
                )
            self._intervals.append((chunk["interval"][0], chunk["interval"][1]))

        self._length = sum([chunk["chunk_size"] for chunk in self._chunks])

        self._downloader = None

        if remote_dir:
            self._downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, self._chunks)

    def download_chunk_from_index(self, chunk_index: int) -> None:
        chunk_filename = self._chunks[chunk_index]["filename"]

        local_chunkpath = os.path.join(self._cache_dir, chunk_filename)

        if os.path.exists(local_chunkpath):
            return

        if self._downloader is None:
            raise RuntimeError("The downloader should be defined.")

        self._downloader.download_chunk_from_index(chunk_index)

    @property
    def intervals(self) -> List[Tuple[int, int]]:
        if self._intervals is None:
            raise RuntimeError("The intervals should be defined.")
        return self._intervals

    @property
    def data_format(self) -> Any:
        if self._config is None:
            raise RuntimeError("The config should be defined.")
        return self._config["data_format"]

    @property
    def config(self) -> Dict[str, Any]:
        if self._config is None:
            raise RuntimeError("The config should be defined.")
        return self._config

    def _get_chunk_index_from_index(self, index: int) -> int:
        for chunk_index, interval in enumerate(self._intervals):
            if interval[0] <= index < interval[1]:
                return chunk_index
        raise ValueError(
            f"The provided index {index} didn't find a match within the chunk intervals {self._intervals}."
        )

    def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
        """Find the associated chunk metadata."""
        chunk = self._chunks[index.chunk_index]
        return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]

    @classmethod
    def load(cls, cache_dir: str, remote_dir: Optional[str] = None) -> Optional["ChunksConfig"]:
        cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

        if isinstance(remote_dir, str):
            downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, [])
            downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)

        if not os.path.exists(cache_index_filepath):
            return None

        return ChunksConfig(cache_dir, remote_dir)

    def __len__(self) -> int:
        return self._length
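For reference, the index.json that ChunksConfig consumes would look roughly like this, judging from the keys read above and the writer's get_config further down. The values and chunk filenames are illustrative assumptions (the writer code that produces them is truncated at the end of this diff):

{
    "config": {
        "compression": null,
        "chunk_size": 64,
        "chunk_bytes": null,
        "data_format": ["int", "tensor"],
        "data_spec": "<string produced by torch.utils._pytree.treespec_dumps>"
    },
    "chunks": [
        {"filename": "chunk-0-0.bin", "interval": [0, 64], "chunk_size": 64},
        {"filename": "chunk-0-1.bin", "interval": [64, 128], "chunk_size": 64}
    ]
}

Each chunk's interval is the half-open range of global sample indices it holds, and its width must equal chunk_size, which is exactly the invariant the constructor asserts.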
@@ -0,0 +1,21 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lightning_utilities.core.imports import RequirementCache

_INDEX_FILENAME = "index.json"
_DEFAULT_CHUNK_BYTES = 1 << 26  # 64 MB

# This is required for full pytree serialization / deserialization support
_TORCH_2_1_0_AVAILABLE = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
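RequirementCache evaluates its requirement lazily and is truthy only when the installed package satisfies it, which is why the other modules can gate imports on these flags. A quick sanity check of the default chunk size (pure arithmetic):

assert (1 << 26) == 64 * 1024 * 1024  # 67,108,864 bytes, i.e. 64 MiB per chunk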
@@ -0,0 +1,311 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import inspect
import logging
import os
from importlib import reload
from typing import Any, Callable, List, Optional

import torch
from torch.utils.data import Dataset, IterableDataset
from torch.utils.data._utils.collate import default_collate
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
from torch.utils.data.dataloader import (
    DataLoader,
    _BaseDataLoaderIter,
    _DatasetKind,
    _MultiProcessingDataLoaderIter,
    _SingleProcessDataLoaderIter,
)
from torch.utils.data.sampler import BatchSampler, Sampler

from lightning.data.cache import Cache
from lightning.data.cache.constants import _DEFAULT_CHUNK_BYTES, _TORCH_2_1_0_AVAILABLE, _VIZ_TRACKER_AVAILABLE
from lightning.data.cache.sampler import CacheBatchSampler
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv

if _TORCH_2_1_0_AVAILABLE:
    from torch.utils._pytree import tree_flatten

logger = logging.getLogger(__name__)


def _equal_items(data_1: Any, data_2: Any) -> bool:
    data_1_flattened, _ = tree_flatten(data_1)
    data_2_flattened, _ = tree_flatten(data_2)

    if len(data_1_flattened) != len(data_2_flattened):
        return False

    return all(_equal_item(d1, d2) for d1, d2 in zip(data_1_flattened, data_2_flattened))


def _equal_item(d1: Any, d2: Any) -> bool:
    if not isinstance(d1, type(d2)):
        return False
    equality = d1 == d2
    if isinstance(equality, torch.Tensor):
        return bool(equality.all().item())
    if equality is True:
        return True
    return False


class CacheDataset(Dataset):
    def __init__(
        self,
        dataset: Any,
        cache_dir: str,
        chunk_bytes: Optional[int],
        chunk_size: Optional[int],
        compression: Optional[str],
    ):
        """The `CacheDataset` is a dataset wrapper that provides a beginner-friendly experience with the Cache.

        Arguments:
            dataset: The user's dataset.
            cache_dir: The folder where the chunks are written to.
            chunk_bytes: The maximum number of bytes to write within a chunk.
            chunk_size: The maximum number of items to write to a chunk.
            compression: The compression algorithm used to reduce the size of the chunks.

        """
        self._dataset = dataset
        self._cache = Cache(cache_dir, chunk_bytes=chunk_bytes, chunk_size=chunk_size, compression=compression)
        self._is_deterministic = False

    def __len__(self) -> int:
        return len(self._cache) if self._cache.filled else len(self._dataset)

    def __getitem__(self, index: int) -> Any:
        data_1 = self._cache[index] if self._cache.filled else self._dataset[index]
        if not self._cache.filled:
            if not self._is_deterministic:
                data_2 = self._dataset[index]
                if not _equal_items(data_1, data_2):
                    raise ValueError(
                        f"Your dataset items aren't deterministic. Found {data_1} and {data_2} for index {index}."
                        " HINT: Use the `lightning.data.cache.Cache` directly within your dataset."
                    )
                self._is_deterministic = True
            self._cache[index] = data_1
        return data_1


class CacheCollateFn:
    """The CacheCollateFn is used to accelerate the processing of the data generated using the Cache.

    During the chunking phase, there is no need to return any data from the DataLoader, which saves some time.

    Additionally, if the user makes their __getitem__ asynchronous, the collate executes the coroutines in parallel.

    """

    def __init__(self, collate_fn: Optional[Callable] = None) -> None:
        self.collate_fn = collate_fn or default_collate

    def __call__(self, items: List[Any]) -> Any:
        if all(item is None for item in items):
            return None

        # If the __getitem__ method is asynchronous, collect all the items.
        if all(inspect.iscoroutine(item) for item in items):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            items = loop.run_until_complete(asyncio.gather(*items))

        return self.collate_fn([item for item in items if item is not None])


class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter):
    """Overridden to inform the cache that chunking is done."""

    def _next_data(self) -> Any:
        try:
            data = None
            while data is None:
                data = super()._next_data()
            return data
        except StopIteration:
            for v in self._dataset_fetcher.dataset.__dict__.values():
                if isinstance(v, Cache):
                    v.done()
                    if not v.filled:
                        v.merge(1)
            raise StopIteration()


class WorkerLoop:
    """Wrap the PyTorch DataLoader WorkerLoop to perform caching and profiling."""

    def __init__(self, global_rank: int, profile: bool = False) -> None:
        self._global_rank = global_rank
        self._profile = profile

    def __call__(self, dataset_kind: _DatasetKind, *args: Any, **kwargs: Any) -> None:
        from torch.utils.data._utils import worker

        from lightning.data.cache.cache import Cache

        rank = _WorkerEnv.detect().rank
        enable_profiling = self._global_rank == 0 and rank == 0 and _VIZ_TRACKER_AVAILABLE and self._profile

        if enable_profiling:
            from viztracer import VizTracer

            tracer = VizTracer(output_file=os.path.join(os.getcwd(), "trace.json"))
            tracer.start()

        # Reload to remove the patching
        reloaded_worker = reload(worker)
        create_fetcher = _DatasetKind.create_fetcher
        fetcher = None

        def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
            nonlocal fetcher
            fetcher = create_fetcher(*args, **kwargs)
            return fetcher

        _DatasetKind.create_fetcher = create_fetcher_fn  # type: ignore

        reloaded_worker._worker_loop(dataset_kind, *args, **kwargs)

        if dataset_kind == _DatasetKind.Map:
            assert fetcher
            for v in fetcher.dataset.__dict__.values():
                if isinstance(v, Cache):
                    v.done()

        if enable_profiling:
            tracer.stop()
            tracer.save()


class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter):
    def __init__(self, loader: DataLoader) -> None:
        self._cache = loader._cache
        self._num_workers = loader.num_workers
        # Patch the PyTorch worker loop to call the `cache.done()` method.
        from torch.utils.data._utils import worker

        worker._worker_loop = WorkerLoop(loader._global_rank, loader._profile)
        super().__init__(loader)

    def _shutdown_workers(self) -> None:
        super()._shutdown_workers()

        # If the cache isn't filled yet, merge the index files created by the workers
        if not self._cache.filled:
            self._cache.merge(self._num_workers)

    def _next_data(self) -> Any:
        try:
            data = None
            while data is None:
                data = super()._next_data()
            return data
        except StopIteration as e:
            raise e


class LightningDataLoader(DataLoader):
    __doc__ = DataLoader.__doc__

    def __init__(
        self,
        dataset: Any,
        *args: Any,
        sampler: Optional[Sampler] = None,
        batch_sampler: Optional[BatchSampler] = None,
        num_workers: int = 0,
        shuffle: bool = False,
        generator: Optional[torch.Generator] = None,
        batch_size: Optional[int] = None,
        drop_last: bool = False,
        cache_dir: Optional[str] = None,
        chunk_bytes: Optional[int] = _DEFAULT_CHUNK_BYTES,
        compression: Optional[str] = None,
        profile: bool = False,
        collate_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        if sampler:
            raise ValueError(
                "The LightningDataLoader relies on its own internal sampler. Passing a sampler isn't supported."
            )

        if batch_sampler:
            raise ValueError(
                "The LightningDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported."
            )

        if isinstance(dataset, IterableDataset):
            raise ValueError("Only map-style datasets are supported by the LightningDataLoader for now.")

        if profile and not _VIZ_TRACKER_AVAILABLE:
            raise ModuleNotFoundError("To enable DataLoader profiling, run `pip install viztracer`.")

        cache_list = [v for v in dataset.__dict__.values() if isinstance(v, Cache)]

        if len(cache_list) > 1:
            raise ValueError(
                "We found several Cache objects used as attributes of your dataset. Only one is supported for now."
            )

        if len(cache_list) == 0:
            if cache_dir is None:
                raise ValueError("You should provide a `cache_dir` filepath to the LightningDataLoader.")

            dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size, compression)
            cache = dataset._cache
        else:
            cache = cache_list[0]

        if not cache.filled and shuffle:
            logger.info("Shuffle is ignored during the caching phase.")

        self._cache = cache

        distributed_env = _DistributedEnv.detect()
        self._global_rank = distributed_env.global_rank

        batch_sampler = CacheBatchSampler(
            len(dataset),
            distributed_env.world_size,
            self._global_rank,
            num_workers,
            batch_size or 1,
            drop_last,
            shuffle,
            cache,
        )

        self._profile = profile

        super().__init__(
            dataset,
            *args,
            batch_sampler=batch_sampler,  # type: ignore
            collate_fn=CacheCollateFn(collate_fn),
            num_workers=num_workers,
            **kwargs,
        )

    def _get_iterator(self) -> "_BaseDataLoaderIter":
        """Overridden to ensure the `Cache.done()` method is triggered when iteration ends."""
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIterPatch(self)
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIterPatch(self)
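End to end, the loader transparently chunks the dataset on the first pass and serves it from the cache afterwards. A minimal sketch, assuming a deterministic map-style dataset and a writable cache directory (the dataset, path, and sizes below are illustrative, not from the commit):

import os

from torch.utils.data import Dataset

from lightning.data.cache import LightningDataLoader


class SquaresDataset(Dataset):
    def __len__(self) -> int:
        return 100

    def __getitem__(self, index: int) -> int:
        return index * index  # must be deterministic, as CacheDataset verifies


cache_dir = "/tmp/squares_cache"  # hypothetical path
os.makedirs(cache_dir, exist_ok=True)

loader = LightningDataLoader(SquaresDataset(), cache_dir=cache_dir, batch_size=4, num_workers=2)

for _ in loader:  # epoch 1: fills the cache chunk by chunk
    pass
for batch in loader:  # epoch 2: reads the optimized chunks via the Cache
    pass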
@@ -0,0 +1,74 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
from urllib import parse


class Downloader(ABC):
    def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
        self._remote_dir = remote_dir
        self._cache_dir = cache_dir
        self._chunks = chunks

    def download_chunk_from_index(self, chunk_index: int) -> None:
        chunk_filename = self._chunks[chunk_index]["filename"]
        local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
        remote_chunkpath = os.path.join(self._remote_dir, chunk_filename)
        self.download_file(remote_chunkpath, local_chunkpath)

    @abstractmethod
    def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
        pass


class S3Downloader(Downloader):
    @classmethod
    def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
        import boto3
        from boto3.s3.transfer import TransferConfig
        from botocore.config import Config

        obj = parse.urlparse(remote_filepath)

        if obj.scheme != "s3":
            raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")

        extra_args: Dict[str, Any] = {}

        # Create a new session per thread
        session = boto3.session.Session()
        # Create a client using the thread's session object
        s3 = session.client("s3", config=Config(read_timeout=None))
        # Threads calling S3 operations return RuntimeError (cannot schedule new futures after
        # interpreter shutdown). Temporary solution is to have `use_threads` as `False`.
        # Issue: https://github.com/boto/boto3/issues/3113
        s3.download_file(
            obj.netloc,
            obj.path.lstrip("/"),
            local_filepath,
            ExtraArgs=extra_args,
            Config=TransferConfig(use_threads=False),
        )


# TODO: Add fsspec support
_DOWNLOADERS = {"s3://": S3Downloader}


def get_downloader_cls(remote_dir: str) -> Type[Downloader]:
    for k, cls in _DOWNLOADERS.items():
        if remote_dir.startswith(k):
            return cls
    raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have an associated downloader.")
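The scheme-to-class registry makes it straightforward to support more backends later (the TODO mentions fsspec). As a sketch only, and not part of this commit, a hypothetical downloader for a local:// scheme could plug in like this:

import shutil

from lightning.data.cache.downloader import _DOWNLOADERS, Downloader


class LocalDownloader(Downloader):
    """Hypothetical downloader that 'downloads' by copying from a local directory."""

    def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
        source = remote_chunkpath.replace("local://", "", 1)
        shutil.copyfile(source, local_chunkpath)


_DOWNLOADERS["local://"] = LocalDownloader  # picked up by get_downloader_cls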
@@ -0,0 +1,177 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from threading import Lock, Thread
from time import sleep
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from lightning.data.cache.config import ChunksConfig
from lightning.data.cache.constants import _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.sampler import ChunkedIndex
from lightning.data.cache.serializers import _SERIALIZERS, Serializer
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv

if _TORCH_2_1_0_AVAILABLE:
    from torch.utils._pytree import PyTree, tree_unflatten


class PrepareChunksThread(Thread):
    """This thread is responsible for downloading the chunks associated with a given worker."""

    def __init__(self, config: ChunksConfig) -> None:
        super().__init__(daemon=True)
        self._config = config
        self._chunks_index_to_be_processed: List[int] = []
        self._chunks_index_to_ready: List[int] = []
        self._lock = Lock()

    def add(self, chunk_indices: List[int]) -> None:
        """Receive the list of the chunk indices to download for the current epoch."""
        with self._lock:
            self._chunks_index_to_be_processed.extend(chunk_indices)

    def run(self) -> None:
        while True:
            with self._lock:
                if len(self._chunks_index_to_be_processed) == 0:
                    sleep(0.007)
                    continue

                chunk_index = self._chunks_index_to_be_processed.pop(0)

            # TODO: Implement eviction
            self._config.download_chunk_from_index(chunk_index)
            self._chunks_index_to_ready.append(chunk_index)


class BinaryReader:
    def __init__(self, cache_dir: str, remote_dir: Optional[str] = None, compression: Optional[str] = None) -> None:
        """The BinaryReader enables reading a chunked dataset in an efficient way.

        Arguments:
            cache_dir: The path to the cache folder.
            remote_dir: The path to a remote folder where the data is located.
                The scheme needs to be added to the path.
            compression: The algorithm used to decompress the chunks.

        """
        super().__init__()
        self._cache_dir = cache_dir
        self._remote_dir = remote_dir

        if not os.path.exists(self._cache_dir):
            raise FileNotFoundError(f"The provided cache_dir `{self._cache_dir}` doesn't exist.")

        self._compression = compression
        self._intervals: Optional[List[Tuple[int, int]]] = None

        self._serializers: Dict[str, Serializer] = _SERIALIZERS
        self._distributed_env = _DistributedEnv.detect()
        self._rank: Optional[int] = None
        self._config: Optional[ChunksConfig] = None
        self._prepare_thread: Optional[PrepareChunksThread] = None

    def _get_chunk_index_from_index(self, index: int) -> int:
        # Load the config containing the index
        if self._config is None and self._try_load_config() is None:
            raise Exception("The reader index isn't defined.")

        return self._config._get_chunk_index_from_index(index)  # type: ignore

    def _try_load_config(self) -> Optional[ChunksConfig]:
        """Try to load the chunks config if the index files are available."""
        self._config = ChunksConfig.load(self._cache_dir, self._remote_dir)
        return self._config

    @property
    def config(self) -> ChunksConfig:
        if self._config is None:
            raise RuntimeError("The config should be defined.")
        return self._config

    @property
    def rank(self) -> int:
        """Returns the rank of the reader."""
        if self._rank is None:
            self._worker_env = _WorkerEnv.detect()
            self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank
        return self._rank

    def read(self, index: ChunkedIndex) -> Any:
        """Read an item from a chunk for the given index.

        If the chunk isn't available locally or in memory, it will be downloaded.

        Prefetching should reduce the wait time for the batch to be available.

        """
        if not isinstance(index, ChunkedIndex):
            raise ValueError("The Reader.read(...) method expects a chunked Index.")

        # Load the config containing the index
        if self._config is None and self._try_load_config() is None:
            raise Exception("The reader index isn't defined.")

        # Create and start the prepare chunks thread
        if index.chunk_indexes is not None and self._prepare_thread is None and self._config:
            self._prepare_thread = PrepareChunksThread(self._config)
            self._prepare_thread.start()
            self._prepare_thread.add(index.chunk_indexes)

        # Fetch the element
        chunk_filepath, begin, _ = self.config[index]
        raw_item_data = self.load_item_from_chunk(index.index, chunk_filepath, begin)
        return self.deserialize(raw_item_data)

    def deserialize(self, raw_item_data: bytes) -> "PyTree":
        """Deserialize the raw bytes into their python equivalent."""
        idx = len(self.config.data_format) * 4
        sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
        data = []
        for size, data_format in zip(sizes, self.config.data_format):
            serializer = self._serializers[data_format]
            data_bytes = raw_item_data[idx : idx + size]
            data.append(serializer.deserialize(data_bytes))
            idx += size
        return tree_unflatten(data, self.config.config["data_spec"])

    def load_item_from_chunk(self, index: int, chunk_filepath: str, begin: int) -> bytes:
        offset = (1 + (index - begin)) * 4

        while not os.path.exists(chunk_filepath):
            sleep(0.0001)

        with open(chunk_filepath, "rb", 0) as fp:
            fp.seek(offset)
            pair = fp.read(8)
            begin, end = np.frombuffer(pair, np.uint32)
            fp.seek(begin)
            data = fp.read(end - begin)
        return data

    def get_length(self) -> int:
        """Get the number of samples across all chunks."""
        if self._config is None and self._try_load_config() is None:
            raise Exception("The reader index isn't defined.")

        return len(self.config)

    def get_chunk_interval(self) -> List[Tuple[int, int]]:
        """Get the index interval of each chunk."""
        if self._config is None and self._try_load_config() is None:
            raise Exception("The reader index isn't defined.")

        return self.config.intervals
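The seek arithmetic in load_item_from_chunk implies a simple chunk layout: a table of little-endian uint32 offsets at the start of the file (one leading word, presumably the item count, then one entry per item, so entry i sits at byte (1 + i) * 4), followed by the item payloads. Each payload starts with one uint32 size per flattened field and then the serialized bytes, which is what deserialize unpacks. A standalone sketch of reading item i under that assumed layout (field decoding omitted):

import numpy as np


def read_item_bytes(chunk_filepath: str, i: int) -> bytes:
    """Return the raw payload of local item i, following the offset-table layout described above."""
    with open(chunk_filepath, "rb") as fp:
        fp.seek((1 + i) * 4)  # skip the leading count word and earlier entries
        begin, end = np.frombuffer(fp.read(8), np.uint32)  # this item's [begin, end) byte range
        fp.seek(begin)
        return fp.read(int(end - begin))

Here i corresponds to index - begin in the reader, i.e. the item's position within its chunk rather than its global index.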
@@ -0,0 +1,226 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Union

import numpy as np

logger = logging.getLogger(__name__)


@dataclass
class ChunkedIndex:
    index: int
    chunk_index: int
    chunk_indexes: Optional[List[int]] = None


class CacheBatchSampler:
    def __init__(
        self,
        dataset_size: int,
        num_replicas: int,
        global_rank: int,
        num_workers: int,
        batch_size: int,
        drop_last: bool,
        shuffle: bool,
        cache: Any,
    ):
        """The CacheBatchSampler handles the generation of batch indices.

        If the cache isn't filled, the batch sampler hands each worker ordered, contiguous indices so the writer can
        chunk the dataset. If the cache is filled, it acts as a normal BatchSampler.

        Arguments:
            dataset_size: The size of the dataset.
            num_replicas: The number of processes involved in the distributed training.
            global_rank: The global_rank of the given process.
            num_workers: The number of workers provided to the DataLoader.
            batch_size: The number of items in a batch.
            drop_last: Whether to drop the last batch of data.
            shuffle: Whether the data should be shuffled.
            cache: The cache associated with the dataset.

        """
        self._dataset_size = dataset_size
        self._num_replicas = num_replicas
        self._global_rank = global_rank
        self._cache = cache
        self._shuffle = shuffle
        self._num_workers = num_workers or 1
        self._shuffled_chunk_intervals = None
        self._batch_size = batch_size

        self._drop_last = drop_last
        self._length = 0

        # Before starting, ensures the chunk indices are properly defined.
        self._validate()

    def _validate(self) -> None:
        """Checks that each worker gets successive indices."""
        if self._num_workers > 1 and not self._cache.filled:
            batches: Dict[int, Any] = {}
            for batch_index, batch_indices in enumerate(self):
                self._length += 1
                worker_index = batch_index % self._num_workers
                if worker_index not in batches:
                    batches[worker_index] = []
                    batches[worker_index].extend(batch_indices)
                elif len(batch_indices) > 0:
                    batches[worker_index].extend(batch_indices)

            for indices in batches.values():
                indices = np.asarray(indices)
                diff = indices[1:] - (indices[:-1] + 1)
                if diff.sum() != 0:
                    raise RuntimeError("This shouldn't have happened. There is a bug in the CacheSampler.")

    def __iter__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
        # When the cache is filled, we need to iterate through the chunks
        if self._cache.filled:
            if self._num_replicas == 1:
                return self.__iter_from_chunks_non_distributed__()
            return self.__iter_from_chunks_distributed__()

        # shuffle is ignored while building the binarized version of the dataset
        if self._num_replicas == 1:
            return self.__iter_non_distributed__()
        return self.__iter_distributed__()

    def __iter_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
        worker_size = self._dataset_size // self._num_workers
        indices = list(range(self._dataset_size))
        worker_indices = []
        for worker_idx in range(self._num_workers):
            is_last = worker_idx == self._num_workers - 1
            start = worker_idx * worker_size
            end = self._dataset_size if is_last else (worker_idx + 1) * worker_size
            worker_indices.append(indices[start:end])

        assert sum([len(s) for s in worker_indices]) == self._dataset_size

        worker_indices_batches = [self._chunk_list(indices, self._batch_size) for indices in worker_indices]

        yield from self.__iter_indices_per_workers__(worker_indices_batches)

    def __iter_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
        self.indices = list(range(self._dataset_size))
        replica_size = self._dataset_size // self._num_replicas
        worker_size = self._dataset_size // (self._num_replicas * self._num_workers)
        for rank in range(self._num_replicas):
            if rank != self._global_rank:
                continue

            is_last_replica = rank == self._num_replicas - 1
            start_replica = rank * replica_size
            end_replica = self._dataset_size if is_last_replica else (rank + 1) * replica_size
            replica_indices = self.indices[start_replica:end_replica]

            replica_size = len(replica_indices)

            worker_indices = []
            for worker_idx in range(self._num_workers):
                is_last_worker = worker_idx == self._num_workers - 1
                start_worker = worker_idx * worker_size
                end_worker = replica_size if is_last_worker else (worker_idx + 1) * worker_size
                worker_indices.append(replica_indices[start_worker:end_worker])

            assert sum([len(s) for s in worker_indices]) == len(replica_indices)

            worker_indices_batches = [self._chunk_list(indices, self._batch_size) for indices in worker_indices]

            yield from self.__iter_indices_per_workers__(worker_indices_batches)

    def __iter_from_chunks_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
        chunk_intervals = self._cache.get_chunk_interval()
        shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
        shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
        yield from self.__iter_from_shuffled_chunks(shuffled_indexes.tolist(), shuffled_chunk_intervals)

    def __iter_from_chunks_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
        chunk_intervals = self._cache.get_chunk_interval()
        shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
        shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

        replica_chunks = []
        replica_intervals = []
        for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
            if index % self._num_replicas == self._global_rank:
                replica_chunks.append(chunk_index)
                replica_intervals.append(chunk_interval)

        yield from self.__iter_from_shuffled_chunks(replica_chunks, replica_intervals)

    def __iter_from_shuffled_chunks(
        self, shuffled_indexes: List[int], shuffled_chunk_intervals: List[List[int]]
    ) -> Iterator[List[Union[int, ChunkedIndex]]]:
        chunks_per_workers: List[List[int]] = [[] for _ in range(self._num_workers)]
        for i, chunk_index in enumerate(shuffled_indexes):
            chunks_per_workers[i % self._num_workers].append(chunk_index)

        indices_per_workers: List[List[ChunkedIndex]] = [[] for _ in range(self._num_workers)]

        for i, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
            worker_id = i % self._num_workers
            interval_indices = np.arange(chunk_interval[0], chunk_interval[1])
            shuffled_interval_indices = np.random.permutation(interval_indices).tolist()
            is_empty = len(indices_per_workers[worker_id]) == 0
            indices_per_workers[worker_id].extend(
                [
                    ChunkedIndex(
                        index,
                        chunk_index,
                        chunk_indexes=chunks_per_workers[worker_id] if j == 0 and is_empty else None,
                    )
                    for j, index in enumerate(shuffled_interval_indices)
                ]
            )

        indices_per_workers_splitted = [self._chunk_list(indices, self._batch_size) for indices in indices_per_workers]

        yield from self.__iter_indices_per_workers__(indices_per_workers_splitted)

    def __len__(self) -> int:
        return self._length

    def __iter_indices_per_workers__(
        self, indices_per_workers: List[List[List[Union[int, ChunkedIndex]]]]
    ) -> Iterator[List[Union[int, ChunkedIndex]]]:
        batches: List[List[Union[int, ChunkedIndex]]] = []
        counter = 0
        while sum([len(v) for v in indices_per_workers]) != 0:
            worker_indices = indices_per_workers[counter % self._num_workers]
            if len(worker_indices) == 0:
                batches.append([])
            else:
                batches.append(worker_indices.pop(0))
            counter += 1

        # Drop trailing empty batches
        while batches and len(batches[-1]) == 0:
            batches.pop(-1)

        yield from batches

    def _chunk_list(self, arr: List[Any], chunk_size: int) -> List[List[Any]]:
        out = []
        for i in range(0, len(arr), chunk_size):
            slice_item = slice(i, i + chunk_size, 1)
            out.append(arr[slice_item])
        return out
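To make the caching-phase contract concrete, consider dataset_size=8, num_workers=2, batch_size=2 on a single replica. __iter_non_distributed__ assigns the contiguous ranges [0..3] and [4..7] to the two workers, chunks each into batches, and __iter_indices_per_workers__ interleaves them round-robin:

# Yielded batch order (PyTorch assigns batch_index % num_workers to a DataLoader worker):
#   [0, 1]  -> worker 0
#   [4, 5]  -> worker 1
#   [2, 3]  -> worker 0
#   [6, 7]  -> worker 1
# Worker 0 therefore writes items 0..3 and worker 1 writes items 4..7, each strictly
# successive, which is exactly what _validate asserts before iteration starts.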
@@ -0,0 +1,223 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
from abc import ABC, abstractmethod
from collections import OrderedDict
from io import BytesIO
from typing import Any, Optional, Tuple, Union

import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache

_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")

if _PIL_AVAILABLE:
    from PIL import Image
    from PIL.JpegImagePlugin import JpegImageFile
else:
    Image = None
    JpegImageFile = None

if _TORCH_VISION_AVAILABLE:
    from torchvision.io import decode_jpeg


class Serializer(ABC):
    """The base interface for any serializer.

    A Serializer serializes and deserializes to and from bytes.

    """

    @abstractmethod
    def serialize(self, data: Any) -> Tuple[bytes, Optional[str]]:
        pass

    @abstractmethod
    def deserialize(self, data: bytes) -> Any:
        pass

    @abstractmethod
    def can_serialize(self, data: Any) -> bool:
        pass


class PILSerializer(Serializer):
    """The PILSerializer serializes and deserializes a PIL Image to and from bytes."""

    def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]:
        mode = item.mode.encode("utf-8")
        width, height = item.size
        raw = item.tobytes()
        ints = np.array([width, height, len(mode)], np.uint32)
        return ints.tobytes() + mode + raw, None

    def deserialize(self, data: bytes) -> Any:
        idx = 3 * 4
        width, height, mode_size = np.frombuffer(data[:idx], np.uint32)
        idx2 = idx + mode_size
        mode = data[idx:idx2].decode("utf-8")
        size = width, height
        raw = data[idx2:]
        return Image.frombytes(mode, size, raw)  # pyright: ignore

    def can_serialize(self, item: Any) -> bool:
        return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)


class IntSerializer(Serializer):
    """The IntSerializer serializes and deserializes an integer to and from bytes."""

    def serialize(self, item: int) -> Tuple[bytes, Optional[str]]:
        return str(item).encode("utf-8"), None

    def deserialize(self, data: bytes) -> int:
        return int(data.decode("utf-8"))

    def can_serialize(self, item: Any) -> bool:
        return isinstance(item, int)


class JPEGSerializer(Serializer):
    """The JPEGSerializer serializes and deserializes a JPEG image to and from bytes."""

    def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]:
        if isinstance(item, JpegImageFile):
            if not hasattr(item, "filename"):
                raise ValueError(
                    "The JPEG Image's filename isn't defined. HINT: Open the image in your Dataset __getitem__ method."
                )
            with open(item.filename, "rb") as f:
                return f.read(), None
        raise TypeError(f"The provided item should be of type {JpegImageFile}. Found {item}.")

    def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]:
        if _TORCH_VISION_AVAILABLE:
            array = torch.frombuffer(data, dtype=torch.uint8)
            return decode_jpeg(array)

        inp = BytesIO(data)
        return Image.open(inp)

    def can_serialize(self, item: Any) -> bool:
        return isinstance(item, JpegImageFile)


class BytesSerializer(Serializer):
    """The BytesSerializer serializes and deserializes bytes (a pass-through)."""

    def serialize(self, item: bytes) -> Tuple[bytes, Optional[str]]:
        return item, None

    def deserialize(self, item: bytes) -> bytes:
        return item

    def can_serialize(self, item: bytes) -> bool:
        return isinstance(item, bytes)


_TORCH_DTYPES_MAPPING = {
    0: torch.float32,
    1: torch.float,
    2: torch.float64,
    3: torch.double,
    4: torch.complex64,
    5: torch.cfloat,
    6: torch.complex128,
    7: torch.cdouble,
    8: torch.float16,
    9: torch.half,
    10: torch.bfloat16,  # Not supported https://github.com/pytorch/pytorch/issues/110285
    11: torch.uint8,
    12: torch.int8,
    13: torch.int16,
    14: torch.short,
    15: torch.int32,
    16: torch.int,
    17: torch.int64,
    18: torch.long,
    19: torch.bool,
}


class TensorSerializer(Serializer):
    """The TensorSerializer serializes and deserializes a tensor to and from bytes."""

    def __init__(self) -> None:
        super().__init__()
        self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}

    def serialize(self, item: torch.Tensor) -> Tuple[bytes, Optional[str]]:
        dtype_indice = self._dtype_to_indice[item.dtype]
        data = [np.uint32(dtype_indice).tobytes()]
        data.append(np.uint32(len(item.shape)).tobytes())
        for dim in item.shape:
            data.append(np.uint32(dim).tobytes())
        data.append(item.numpy().tobytes())
        return b"".join(data), None

    def deserialize(self, data: bytes) -> torch.Tensor:
        dtype_indice = np.frombuffer(data[0:4], np.uint32).item()
        dtype = _TORCH_DTYPES_MAPPING[dtype_indice]
        shape_size = np.frombuffer(data[4:8], np.uint32).item()
        shape = []
        for shape_idx in range(shape_size):
            shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())
        tensor = torch.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
        return torch.reshape(tensor, torch.Size(shape))

    def can_serialize(self, item: torch.Tensor) -> bool:
        return isinstance(item, torch.Tensor) and type(item) == torch.Tensor


class PickleSerializer(Serializer):
    """The PickleSerializer serializes and deserializes python objects to and from bytes."""

    def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
        return pickle.dumps(item), None

    def deserialize(self, data: bytes) -> Any:
        return pickle.loads(data)

    def can_serialize(self, _: Any) -> bool:
        return True


class FileSerializer(Serializer):
    def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
        _, file_extension = os.path.splitext(filepath)
        with open(filepath, "rb") as f:
            return f.read(), file_extension.replace(".", "").lower()

    def deserialize(self, data: bytes) -> Any:
        pass

    def can_serialize(self, data: Any) -> bool:
        return isinstance(data, str) and os.path.exists(data)


_SERIALIZERS = OrderedDict(
    **{
        "file": FileSerializer(),
        "pil": PILSerializer(),
        "int": IntSerializer(),
        "jpeg": JPEGSerializer(),
        "bytes": BytesSerializer(),
        "tensor": TensorSerializer(),
        "pickle": PickleSerializer(),
    }
)
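Each serializer is easy to check in isolation with a round trip. A small sketch using the int and tensor serializers above (assuming torch is installed; note torch.frombuffer may warn about the read-only buffer):

import torch

from lightning.data.cache.serializers import IntSerializer, TensorSerializer

int_ser = IntSerializer()
data, _ = int_ser.serialize(42)
assert int_ser.deserialize(data) == 42

tensor_ser = TensorSerializer()
t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
data, _ = tensor_ser.serialize(t)  # uint32 dtype id + ndim + dims, then the raw buffer
assert torch.equal(tensor_ser.deserialize(data), t)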
@ -0,0 +1,305 @@
|
|||
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from time import sleep
from typing import Any, Dict, List, Optional

import numpy as np

from lightning.data.cache.compression import _COMPRESSORS, Compressor
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.serializers import _SERIALIZERS, Serializer
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv

if _TORCH_2_1_0_AVAILABLE:
    from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps


class BinaryWriter:
    def __init__(
        self,
        cache_dir: str,
        chunk_size: Optional[int] = None,
        chunk_bytes: Optional[int] = None,
        compression: Optional[str] = None,
    ):
        """The BinaryWriter chunks a dataset into an efficient streaming format for cloud training.

        Arguments:
            cache_dir: The path to where the chunks will be saved.
            chunk_bytes: The maximum number of bytes within a chunk.
            chunk_size: The maximum number of items within a chunk.
            compression: The compression algorithm to use.

        """
        self._cache_dir = cache_dir

        if not os.path.exists(self._cache_dir):
            raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.")

        if (chunk_size is None and chunk_bytes is None) or (chunk_size and chunk_bytes):
            raise ValueError("Exactly one of `chunk_size` or `chunk_bytes` needs to be provided.")

        self._serializers: Dict[str, Serializer] = _SERIALIZERS
        self._chunk_size = chunk_size
        self._chunk_bytes = chunk_bytes
        self._compression = compression

        self._data_format: Optional[List[str]] = None
        self._data_spec: Optional[PyTree] = None

        if self._compression:
            if len(_COMPRESSORS) == 0:
                raise ValueError("No compression algorithms are installed.")
            if self._compression not in _COMPRESSORS:
                raise ValueError(
                    f"The provided compression {self._compression} isn't available in {sorted(_COMPRESSORS)}"
                )
            self._compressor: Compressor = _COMPRESSORS[self._compression]

        self._current_chunk_bytes = 0
        self._chunk_index = 0
        self._serialized_items: List[bytes] = []
        self._chunks_info: List[Dict[str, Any]] = []
        self._indexes: List[int] = []
        self._worker_env: Optional[_WorkerEnv] = None
        self._rank: Optional[int] = None
        self._is_done = False
        self._distributed_env = _DistributedEnv.detect()

    @property
    def filled(self) -> bool:
        """Returns whether the caching phase is done."""
        if self._is_done:
            return True
        files = os.listdir(self._cache_dir)
        index_files = [f for f in files if f.endswith(_INDEX_FILENAME)]
        worker_env = _WorkerEnv.detect()
        self._is_done = len(index_files) == self._distributed_env.world_size * worker_env.world_size
        return self._is_done

    @property
    def rank(self) -> int:
        """Returns the rank of the writer."""
        if self._rank is None:
            self._worker_env = _WorkerEnv.detect()
            self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank
        return self._rank

    def get_config(self) -> Dict[str, Any]:
        """Returns the config of the writer."""
        out = {
            "compression": self._compression,
            "chunk_size": self._chunk_size,
            "chunk_bytes": self._chunk_bytes,
            "data_format": self._data_format,
            "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None,
        }
        return out

    def serialize(self, items: Any) -> bytes:
        """Serialize a dictionary into its binary format."""

        # Flatten the items provided by the users
        flattened, data_spec = tree_flatten(items)

        # Collect the sizes and associated bytes for each item
        sizes: List[int] = []
        data: List[bytes] = []

        data_format: List[str] = []
        for item in flattened:
            data_format.append(self._serialize(item, sizes, data))

        if self._data_format is None:
            self._data_format = data_format
        elif self._data_format != data_format:
            raise Exception(
                f"The data format changed between items. Found {data_format} instead of {self._data_format}."
            )

        if self._data_spec is None:
            self._data_spec = data_spec
        elif self._data_spec != data_spec:
            raise Exception(f"The data spec changed between items. Found {data_spec} instead of {self._data_spec}.")

        # Concatenate into a single byte array
        head = np.array(sizes, np.uint32).tobytes()
        body = b"".join(data)
        return head + body
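
    # Illustrative sketch (not part of this commit): how a sample produced by
    # `serialize` can be decoded. It assumes the reader knows how many
    # flattened items the sample holds (e.g. from the stored `data_format`):
    #
    #     sizes = np.frombuffer(raw[: 4 * num_items], dtype=np.uint32)
    #     payload, offset, pieces = raw[4 * num_items :], 0, []
    #     for size in sizes.tolist():
    #         pieces.append(payload[offset : offset + size])  # one bytes object per item
    #         offset += size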

    def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> str:
        """Serialize a given item and append its size and bytes to the sizes and data arrays."""
        for serializer_name, serializer in self._serializers.items():
            if serializer.can_serialize(item):
                serialized_item, name = serializer.serialize(item)
                data.append(serialized_item)
                sizes.append(len(serialized_item))
                return name or serializer_name
        raise ValueError(f"The provided item isn't serializable. Found {item}")

    def _create_chunk(self, filename: str) -> bytes:
        """Create a binary chunk from all the binarized items."""
        num_items = np.uint32(len(self._serialized_items))
        sizes = list(map(len, self._serialized_items))
        offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
        offsets += len(num_items.tobytes()) + len(offsets.tobytes())
        sample_data = b"".join(self._serialized_items)
        data = num_items.tobytes() + offsets.tobytes() + sample_data
        offsets = offsets.tolist()
        mapping = {}
        for i in range(len(self._indexes)):
            mapping[self._indexes[i]] = [offsets[i], offsets[i + 1]]

        assert len(mapping) == len(self._indexes)
        assert (self._indexes[-1] - self._indexes[0] + 1) == len(self._serialized_items)

        chunk_info = {
            "chunk_bytes": self._current_chunk_bytes,
            "chunk_size": len(self._serialized_items),
            "filename": filename,
            "interval": [self._indexes[0], self._indexes[-1] + 1],
        }

        self._chunks_info.append(chunk_info)

        return data
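
    # Illustrative sketch (not part of this commit): the chunk layout written
    # by `_create_chunk` is [uint32 num_items][uint32 offsets][samples...], so
    # a reader can slice sample `i` without touching its neighbours:
    #
    #     num_items = int(np.frombuffer(chunk[:4], dtype=np.uint32)[0])
    #     offsets = np.frombuffer(chunk[4 : 4 * (num_items + 2)], dtype=np.uint32)
    #     sample_i = chunk[offsets[i] : offsets[i + 1]]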

    def write_chunk(self) -> None:
        """Write a chunk to the filesystem."""
        if self._compression:
            filename = f"chunk-{self.rank}-{self._chunk_index}.{self._compression}.bin"
        else:
            filename = f"chunk-{self.rank}-{self._chunk_index}.bin"
        self.write_chunk_to_file(self._create_chunk(filename), filename)

    def reset(self) -> None:
        """Reset the writer to handle the next chunk."""
        self._serialized_items = []
        self._indexes = []
        self._current_chunk_bytes = 0

    def __setitem__(self, index: int, items: Any) -> None:
        """Store an item in a chunk.

        The index needs to be provided in order.

        This is handled by the samplers automatically. This ensures we can map an index to a shard from an interval.

        """
        # Serialize the items
        serialized_items = self.serialize(items)
        serialized_items_size = len(serialized_items)

        # Check whether it is time to write a chunk
        should_write = (
            self._chunk_bytes and self._chunk_bytes < self._current_chunk_bytes + serialized_items_size
        ) or (self._chunk_size and len(self._indexes) >= self._chunk_size)

        if should_write:
            if self._current_chunk_bytes == 0:
                raise Exception(
                    f"The provided chunk_bytes {self._chunk_bytes} is too small."
                    f" You should use a multiple of {serialized_items_size} bytes."
                )
            self.write_chunk()
            self.reset()
            self._chunk_index += 1

        # Store the serialized items into the chunk.
        self._serialized_items.append(serialized_items)
        self._current_chunk_bytes += serialized_items_size

        # Validate the indexes are provided in incremental order.
        # This is required to ensure we can efficiently find a chunk index from an index using the chunk intervals.
        if self._indexes:
            assert self._indexes[-1] == index - 1, (self._indexes, index - 1)

        # Store the index
        self._indexes.append(index)

    def write_chunk_to_file(
        self,
        raw_data: bytes,
        filename: str,
    ) -> None:
        """Write chunk bytes to a file."""
        # Whether to compress the raw bytes
        if self._compression:
            raw_data = self._compressor.compress(raw_data)

        # Write the binary chunk file
        with open(os.path.join(self._cache_dir, filename), "wb") as out:
            out.write(raw_data)

    def write_chunks_index(self) -> None:
        """Write the chunks index to a JSON file."""
        filepath = os.path.join(self._cache_dir, f"{self.rank}.{_INDEX_FILENAME}")
        config = self.get_config()
        with open(filepath, "w") as out:
            json.dump({"chunks": self._chunks_info, "config": config}, out, sort_keys=True)
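
    # Illustrative sketch (not part of this commit): the shape of the per-rank
    # index file written above, with hypothetical values:
    #
    #     {"chunks": [{"chunk_bytes": 88, "chunk_size": 10,
    #                  "filename": "chunk-0-0.bin", "interval": [0, 10]}],
    #      "config": {"chunk_size": 10, "chunk_bytes": null, "compression": null,
    #                 "data_format": ["int"], "data_spec": "..."}}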

    def done(self) -> None:
        """Called when StopIteration is triggered."""
        if self.filled:
            return
        if self._serialized_items:
            self.write_chunk()
        self.write_chunks_index()
        self.reset()
        self._is_done = True

    def merge(self, num_workers: int = 1) -> None:
        """Once all the workers have written their own index, the merge function is responsible for reading and
        merging them into a single index."""
        num_workers = num_workers or 1

        # The non-zero ranks only wait for the merged index to appear
        if self.rank != 0:
            while not os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME)):
                sleep(0.001)
            return

        # Wait for all indexes to be available
        is_done = False
        while not is_done:
            files = os.listdir(self._cache_dir)
            if _INDEX_FILENAME in files:
                return
            index_files = [f for f in files if f.endswith(_INDEX_FILENAME)]
            is_done = len(index_files) == self._distributed_env.world_size * num_workers
            sleep(0.001)

        # Read the indexes and append the chunks together
        chunks_info = []
        config = None
        for index_filename in sorted(index_files):
            chunk_path = os.path.join(self._cache_dir, index_filename)
            with open(chunk_path) as f:
                data = json.load(f)

            if config is None:
                config = data["config"]

            elif config != data["config"]:
                raise Exception("The config isn't consistent between chunks. This shouldn't have happened.")

            chunks_info.extend(data["chunks"])

            os.remove(chunk_path)

        # Write down the collected index
        with open(os.path.join(self._cache_dir, _INDEX_FILENAME), "w") as f:
            json.dump({"chunks": chunks_info, "config": config}, f, sort_keys=True)
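
# Illustrative sketch (not part of this commit): a single-process, end-to-end
# use of the writer; `my_cache_dir` is a hypothetical existing directory.
#
#     writer = BinaryWriter(my_cache_dir, chunk_size=10)
#     for index in range(100):
#         writer[index] = {"i": index}
#     writer.done()   # flush the last partial chunk and write this rank's index
#     writer.merge()  # rank 0 merges the per-rank indexes into index.json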

@@ -1,7 +1,7 @@
from typing import Optional
from typing import Callable, Optional

import torch
from torch.utils.data import get_worker_info
from torch.utils.data import get_worker_info as torch_get_worker_info


class _DistributedEnv:

@@ -60,7 +60,7 @@ class _WorkerEnv:

        self.rank = rank

    @classmethod
    def detect(cls) -> "_WorkerEnv":
    def detect(cls, get_worker_info_fn: Optional[Callable] = None) -> "_WorkerEnv":
        """Automatically detects the number of workers and the current rank.

        Note:

@@ -68,6 +68,7 @@ class _WorkerEnv:

            In such a case it will default to 1 worker

        """
        get_worker_info = get_worker_info_fn or torch_get_worker_info
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        current_worker_rank = worker_info.id if worker_info is not None else 0
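
# Illustrative sketch (not part of this commit): the new `get_worker_info_fn`
# hook makes worker detection testable without a real DataLoader worker; the
# stub below is hypothetical.
#
#     class _FakeWorkerInfo:
#         num_workers = 4
#         id = 2
#
#     env = _WorkerEnv.detect(get_worker_info_fn=_FakeWorkerInfo)
#     assert (env.world_size, env.rank) == (4, 2)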

@@ -0,0 +1,205 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from functools import partial

import numpy as np
import pytest
import torch
from lightning import seed_everything
from lightning.data.cache import Cache
from lightning.data.cache.dataloader import LightningDataLoader
from lightning.data.datasets.env import _DistributedEnv
from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning_utilities.core.imports import RequirementCache
from torch.utils.data import Dataset

_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")


class ImageDataset(Dataset):
    def __init__(self, tmpdir, cache, size, num_classes):
        from PIL import Image

        self.data = []
        self.cache = cache

        seed_everything(42)

        for i in range(size):
            path = os.path.join(tmpdir, f"img{i}.jpeg")
            np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8)
            img = Image.fromarray(np_data).convert("L")
            img.save(path, format="jpeg", quality=100)
            self.data.append({"image": path, "class": np.random.randint(num_classes)})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if self.cache.filled:
            return self.cache[index]
        self.cache[index] = {**self.data[index], "index": index}
        return None


def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):
    from PIL import Image
    from torchvision.transforms import PILToTensor

    dataset_size = 85

    cache_dir = os.path.join(tmpdir, "cache")
    distributed_env = _DistributedEnv.detect()

    cache = Cache(cache_dir, chunk_size=10)
    dataset = ImageDataset(tmpdir, cache, dataset_size, 10)
    dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4)

    for _ in dataloader:
        pass

    # Not strictly required but added to avoid a race condition
    if distributed_env.world_size > 1:
        fabric.barrier()

    assert cache.filled

    for i in range(len(dataset)):
        cached_data = dataset[i]
        original_data = dataset.data[i]
        assert cached_data["class"] == original_data["class"]
        original_array = PILToTensor()(Image.open(original_data["image"]))
        assert torch.equal(original_array, cached_data["image"])

    if distributed_env.world_size == 1:
        indexes = []
        dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4)
        for batch in dataloader:
            if batch:
                indexes.extend(batch["index"].numpy().tolist())
        assert len(indexes) == dataset_size

    seed_everything(42)

    dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4, shuffle=True)
    dataloader_iter = iter(dataloader)

    indexes = []
    for batch in dataloader_iter:
        indexes.extend(batch["index"].numpy().tolist())

    if distributed_env.world_size == 1:
        assert len(indexes) == dataset_size

    indexes2 = []
    for batch in dataloader_iter:
        indexes2.extend(batch["index"].numpy().tolist())

    assert indexes2 != indexes


@pytest.mark.skipif(
    condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']"
)
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_cache_for_image_dataset(num_workers, tmpdir):
    cache_dir = os.path.join(tmpdir, "cache")
    os.makedirs(cache_dir)

    _cache_for_image_dataset(num_workers, tmpdir)


def _fabric_cache_for_image_dataset(fabric, num_workers, tmpdir):
    _cache_for_image_dataset(num_workers, tmpdir, fabric=fabric)


@pytest.mark.skipif(
    condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE or sys.platform == "win32",
    reason="Requires: ['pil', 'torchvision']",
)
@pytest.mark.parametrize("num_workers", [2])
def test_cache_for_image_dataset_distributed(num_workers, tmpdir):
    cache_dir = os.path.join(tmpdir, "cache")
    os.makedirs(cache_dir)

    fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn")
    fabric.launch(partial(_fabric_cache_for_image_dataset, num_workers=num_workers, tmpdir=tmpdir))


def test_cache_with_simple_format(tmpdir):
    cache_dir = os.path.join(tmpdir, "cache1")
    os.makedirs(cache_dir)

    cache = Cache(cache_dir, chunk_bytes=90)

    for i in range(100):
        cache[i] = i

    cache.done()
    cache.merge()

    for i in range(100):
        assert i == cache[i]

    cache_dir = os.path.join(tmpdir, "cache2")
    os.makedirs(cache_dir)

    cache = Cache(cache_dir, chunk_bytes=90)

    for i in range(100):
        cache[i] = [i, {0: [i + 1]}]

    cache.done()
    cache.merge()

    for i in range(100):
        assert [i, {0: [i + 1]}] == cache[i]


def test_cache_with_auto_wrapping(tmpdir):
    os.makedirs(os.path.join(tmpdir, "cache_1"), exist_ok=True)

    dataset = RandomDataset(64, 64)
    dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_1"), chunk_bytes=2 << 12)
    for batch in dataloader:
        assert isinstance(batch, torch.Tensor)
    assert sorted(os.listdir(os.path.join(tmpdir, "cache_1"))) == [
        "chunk-0-0.bin",
        "chunk-0-1.bin",
        "chunk-0-2.bin",
        "index.json",
    ]
    # The dataset is now optimised for the cloud

    class RandomDatasetAtRuntime(Dataset):
        def __init__(self, size: int, length: int):
            self.len = length
            self.size = size

        def __getitem__(self, index: int) -> torch.Tensor:
            return torch.randn(1, self.size)

        def __len__(self) -> int:
            return self.len

    os.makedirs(os.path.join(tmpdir, "cache_2"), exist_ok=True)
    dataset = RandomDatasetAtRuntime(64, 64)
    dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_2"), chunk_bytes=2 << 12)
    with pytest.raises(ValueError, match="Your dataset items aren't deterministic"):
        for batch in dataloader:
            pass

@@ -0,0 +1,112 @@
from unittest import mock

import pytest
from lightning import seed_everything
from lightning.data.cache.sampler import CacheBatchSampler


@pytest.mark.parametrize(
    "params",
    [
        (
            21,
            1,
            [[0, 1, 2], [7, 8, 9], [14, 15, 16], [3, 4, 5], [10, 11, 12], [17, 18, 19], [6], [13], [20]],
            [[7, 0, 0], [1, 1, 1], [5, 5, 5], [0, 4, 4], [8, 3, 3], [2, 2, 2], [4], [3], [6]],
        ),
        (
            11,
            1,
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]],
            [[1, 1, 1], [3, 3], [0, 0, 0], [2, 2, 2]],
        ),
        (8, 1, [[0, 1], [2, 3], [4, 5, 6], [], [], [7]], [[1, 1, 2], [3], [0, 0], [2, 2]]),
        (4, 1, [[0], [1], [2, 3]], [[0], [1], [2, 2]]),
        (
            9,
            1,
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]],
            [[0, 0, 0], [1, 1, 1], [2, 2, 2]],
        ),
        (
            19,
            1,
            [[0, 1, 2], [6, 7, 8], [12, 13, 14], [3, 4, 5], [9, 10, 11], [15, 16, 17], [], [], [18]],
            [[0, 0, 0], [1, 1, 1], [5, 5, 5], [2, 2, 2], [4, 4, 4], [3, 3, 3], [6]],
        ),
        (19, 2, [[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[0, 0, 0], [5, 5, 5], [4, 4, 4], [6]]),
    ],
)
def test_cache_batch_sampler(params):
    seed_everything(42)

    cache = mock.MagicMock()
    cache.filled = False
    if params[1] > 1:
        batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache)
        batches = []
        for batch in batch_sampler:
            batches.append(batch)
        assert batches == params[2], batches

        batch_sampler = CacheBatchSampler(params[0], 1, 0, 3, 3, False, True, cache)
        batches = []
        for batch in batch_sampler:
            batches.append(batch)

        chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)]
    else:
        batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache)
        batches = []
        for batch in batch_sampler:
            batches.append(batch)
        assert batches == params[2], batches

        chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)]

    cache.filled = True
    cache.get_chunk_interval.return_value = chunks_interval

    seed_everything(42)

    batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache)

    batches_1 = []
    for batch in batch_sampler:
        batches_1.append(batch)

    def validate_batch(data, check_values):
        # The expectations are identical for both world sizes, so no branching on params[1] is needed.
        assert all(b[0].chunk_indexes is not None for b in data[:3])
        assert all(b[1].chunk_indexes is None if len(b) > 1 else True for b in data[:3])
        assert all(b[0].chunk_indexes is None if len(b) else True for b in data[3:])
        if check_values:
            assert [[x.chunk_index for x in d] for d in data] == params[3]

    validate_batch(batches_1, True)

    batches_2 = []
    for batch in batch_sampler:
        batches_2.append(batch)

    validate_batch(batches_2, False)
    if params[1] == 1:
        assert batches_1 != batches_2


def test_batch_sampler_imagenet():
    """Validate the batch sampler can handle an ImageNet-sized dataset."""
    dataset_size = 1281167
    world_size = 1
    rank = 0
    num_workers = 32
    batch_size = 8
    cache = mock.MagicMock()
    cache.filled = False
    CacheBatchSampler(dataset_size, world_size, rank, num_workers, batch_size, False, True, cache)

@@ -0,0 +1,115 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from time import time

import numpy as np
import pytest
import torch
from lightning import seed_everything
from lightning.data.cache.serializers import (
    _SERIALIZERS,
    _TORCH_DTYPES_MAPPING,
    IntSerializer,
    PickleSerializer,
    PILSerializer,
    TensorSerializer,
)
from lightning_utilities.core.imports import RequirementCache

_PIL_AVAILABLE = RequirementCache("PIL")


def test_serializers():
    assert list(_SERIALIZERS.keys()) == ["file", "pil", "int", "jpeg", "bytes", "tensor", "pickle"]


def test_int_serializer():
    serializer = IntSerializer()

    for i in range(100):
        data, _ = serializer.serialize(i)
        assert isinstance(data, bytes)
        assert i == serializer.deserialize(data)


@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
@pytest.mark.parametrize("mode", ["I", "L", "RGB"])
def test_pil_serializer(mode):
    serializer = PILSerializer()

    from PIL import Image

    np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32)
    img = Image.fromarray(np_data).convert(mode)

    data, _ = serializer.serialize(img)
    assert isinstance(data, bytes)

    deserialized_img = serializer.deserialize(data)
    deserialized_img = deserialized_img.convert("I")
    np_dec_data = np.asarray(deserialized_img, dtype=np.uint32)
    assert isinstance(deserialized_img, Image.Image)

    # Validate the data content
    assert np.array_equal(np_data, np_dec_data)


@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_tensor_serializer():
    seed_everything(42)

    serializer_tensor = TensorSerializer()
    serializer_pickle = PickleSerializer()

    ratio_times = []
    ratio_bytes = []
    shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]
    for dtype in _TORCH_DTYPES_MAPPING.values():
        for shape in shapes:
            # Not serializable for some reason
            if dtype in [torch.bfloat16]:
                continue
            tensor = torch.ones(shape, dtype=dtype)

            t0 = time()
            data, _ = serializer_tensor.serialize(tensor)
            deserialized_tensor = serializer_tensor.deserialize(data)
            tensor_time = time() - t0
            tensor_bytes = len(data)

            assert deserialized_tensor.dtype == dtype
            assert torch.equal(tensor, deserialized_tensor)

            t1 = time()
            data, _ = serializer_pickle.serialize(tensor)
            deserialized_tensor = serializer_pickle.deserialize(data)
            pickle_time = time() - t1
            pickle_bytes = len(data)

            assert deserialized_tensor.dtype == dtype
            assert torch.equal(tensor, deserialized_tensor)

            ratio_times.append(pickle_time / tensor_time)
            ratio_bytes.append(pickle_bytes / tensor_bytes)

    assert np.mean(ratio_times) > 3.5
    assert np.mean(ratio_bytes) > 2


def test_assert_bfloat16_tensor_serializer():
    serializer = TensorSerializer()
    tensor = torch.ones((10,), dtype=torch.bfloat16)
    with pytest.raises(TypeError, match="Got unsupported ScalarType BFloat16"):
        serializer.serialize(tensor)

@@ -0,0 +1,168 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import numpy as np
import pytest
from lightning.data.cache.reader import BinaryReader
from lightning.data.cache.sampler import ChunkedIndex
from lightning.data.cache.writer import BinaryWriter
from lightning_utilities.core.imports import RequirementCache

_PIL_AVAILABLE = RequirementCache("PIL")


def test_binary_writer_with_ints_and_chunk_bytes(tmpdir):
    with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."):
        BinaryWriter("dontexists", {})

    with pytest.raises(ValueError, match="No compression algorithms are installed."):
        BinaryWriter(tmpdir, {"i": "int"}, compression="something_else")

    binary_writer = BinaryWriter(tmpdir, chunk_bytes=90)

    for i in range(100):
        binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2}

    assert len(os.listdir(tmpdir)) == 19
    binary_writer.done()
    binary_writer.merge()
    assert len(os.listdir(tmpdir)) == 21

    with open(os.path.join(tmpdir, "index.json")) as f:
        data = json.load(f)

    assert data["chunks"][0]["chunk_size"] == 6
    assert data["chunks"][1]["chunk_size"] == 5
    assert data["chunks"][-1]["chunk_size"] == 4

    chunk_sizes = np.cumsum([chunk["chunk_size"] for chunk in data["chunks"]])

    reader = BinaryReader(tmpdir)
    for i in range(100):
        for chunk_index, chunk_start in enumerate(chunk_sizes):
            if i >= chunk_start:
                continue
            break
        data = reader.read(ChunkedIndex(i, chunk_index=chunk_index))
        assert data == {"i": i, "i+1": i + 1, "i+2": i + 2}


def test_binary_writer_with_ints_and_chunk_size(tmpdir):
    with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."):
        BinaryWriter("dontexists", {})

    with pytest.raises(ValueError, match="No compression algorithms are installed."):
        BinaryWriter(tmpdir, {"i": "int"}, compression="something_else")

    binary_writer = BinaryWriter(tmpdir, chunk_size=25)

    for i in range(100):
        binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2}

    assert len(os.listdir(tmpdir)) == 3
    binary_writer.done()
    binary_writer.merge()
    assert len(os.listdir(tmpdir)) == 5

    with open(os.path.join(tmpdir, "index.json")) as f:
        data = json.load(f)

    assert data["chunks"][0]["chunk_size"] == 25
    assert data["chunks"][1]["chunk_size"] == 25
    assert data["chunks"][-1]["chunk_size"] == 25

    reader = BinaryReader(tmpdir)
    for i in range(100):
        data = reader.read(ChunkedIndex(i, chunk_index=i // 25))
        assert data == {"i": i, "i+1": i + 1, "i+2": i + 2}


@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
def test_binary_writer_with_jpeg_and_int(tmpdir):
    """Validate the writer and reader can serialize / deserialize a pair of image and label."""
    from PIL import Image

    cache_dir = os.path.join(tmpdir, "chunks")
    os.makedirs(cache_dir, exist_ok=True)
    binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12)

    imgs = []

    for i in range(100):
        path = os.path.join(tmpdir, f"img{i}.jpeg")
        np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8)
        img = Image.fromarray(np_data).convert("L")
        img.save(path, format="jpeg", quality=100)
        img = Image.open(path)
        imgs.append(img)
        binary_writer[i] = {"x": img, "y": i}

    assert len(os.listdir(cache_dir)) == 24
    binary_writer.done()
    binary_writer.merge()
    assert len(os.listdir(cache_dir)) == 26

    with open(os.path.join(cache_dir, "index.json")) as f:
        data = json.load(f)

    assert data["chunks"][0]["chunk_size"] == 4
    assert data["chunks"][1]["chunk_size"] == 4
    assert data["chunks"][-1]["chunk_size"] == 4

    reader = BinaryReader(cache_dir)
    for i in range(100):
        data = reader.read(ChunkedIndex(i, chunk_index=i // 4))
        np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i])
        assert data["y"] == i


@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
def test_binary_writer_with_jpeg_filepath_and_int(tmpdir):
    """Validate the writer and reader can serialize / deserialize a pair of image filepath and label."""
    from PIL import Image

    cache_dir = os.path.join(tmpdir, "chunks")
    os.makedirs(cache_dir, exist_ok=True)
    binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12)

    imgs = []

    for i in range(100):
        path = os.path.join(tmpdir, f"img{i}.jpeg")
        np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8)
        img = Image.fromarray(np_data).convert("L")
        img.save(path, format="jpeg", quality=100)
        img = Image.open(path)
        imgs.append(img)
        binary_writer[i] = {"x": path, "y": i}

    assert len(os.listdir(cache_dir)) == 24
    binary_writer.done()
    binary_writer.merge()
    assert len(os.listdir(cache_dir)) == 26

    with open(os.path.join(cache_dir, "index.json")) as f:
        data = json.load(f)

    assert data["chunks"][0]["chunk_size"] == 4
    assert data["chunks"][1]["chunk_size"] == 4
    assert data["chunks"][-1]["chunk_size"] == 4

    reader = BinaryReader(cache_dir)
    for i in range(100):
        data = reader.read(ChunkedIndex(i, chunk_index=i // 4))
        np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i])
        assert data["y"] == i

@@ -17,8 +17,7 @@ def get_test_index_data(index_path):
    return list(dict.fromkeys([item.split("/")[-1] for item in data if "jpeg" in item]))


@pytest.fixture(scope="session")
def image_set(tmp_path_factory):
def image_set(tmpdir):
    from PIL import Image

    file_nums = [

@@ -45,11 +44,11 @@ def image_set(tmp_path_factory):

    img = img.astype(np.uint8)
    im = Image.fromarray(img)

    for i in file_nums:
        fn = tmp_path_factory.mktemp("test_data") / f"img-{i}.jpeg"
        im.save(fn)
    folder_path = os.path.join(tmpdir, "test_data")
    os.makedirs(folder_path, exist_ok=True)

    return tmp_path_factory.getbasetemp()._str
    for i in file_nums:
        im.save(os.path.join(folder_path, f"img-{i}.jpeg"))


@pytest.mark.xfail(strict=False, reason="Need a valid AWS key and AWS secret key in CI for this to work")

@@ -70,7 +69,6 @@ def test_get_index_generate_for_s3_bucket(monkeypatch):

    test_bucket = "s3://nohaspublictestbucket"
    index_path = os.path.join(os.getcwd(), "index_1.txt")
    print(index_path)
    got_index = get_index(s3_connection_path=test_bucket, index_file_path=index_path)

    assert got_index

@@ -80,13 +78,16 @@ def test_get_index_generate_for_s3_bucket(monkeypatch):

    assert len(test_index_data) == len(generated_index)
    assert test_index_data == generated_index
    os.remove(index_path)


@pytest.mark.skipif(not package_available("lightning"), reason="Supported only with mono-package")
@mock.patch("lightning.data.datasets.index.LightningClient", MagicMock())
def test_get_index_generate_for_local_folder(image_set, monkeypatch):
def test_get_index_generate_for_local_folder(monkeypatch, tmpdir):
    """Can generate an index for a local folder."""

    image_set(tmpdir)

    client = MagicMock()
    client.projects_service_list_project_cluster_bindings.return_value = None
    client.data_connection_service_list_data_connections.return_value = None

@@ -100,7 +101,7 @@ def test_get_index_generate_for_local_folder(image_set, monkeypatch):

    # test_local_bucket = "data/test_dataset"
    index_path = os.path.join(THIS_DIR, "index_2.txt")
    got_index = get_index(s3_connection_path=image_set, index_file_path=index_path)
    got_index = get_index(s3_connection_path=str(tmpdir), index_file_path=index_path)

    assert got_index