diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index a8bbe1d341..0aef571eaa 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -173,9 +173,9 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "data-cpu (macOS-11, lightning, 3.10, 2.0)" - - "data-cpu (ubuntu-20.04, lightning, 3.10, 2.0)" - - "data-cpu (windows-2022, lightning, 3.10, 2.0)" + - "data-cpu (macOS-11, lightning, 3.10, 2.1)" + - "data-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" + - "data-cpu (windows-2022, lightning, 3.10, 2.1)" # SECTION: lightning_fabric diff --git a/.github/workflows/ci-tests-data.yml b/.github/workflows/ci-tests-data.yml index 6d303c2609..4de87f501f 100644 --- a/.github/workflows/ci-tests-data.yml +++ b/.github/workflows/ci-tests-data.yml @@ -34,9 +34,9 @@ jobs: fail-fast: false matrix: include: - - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } + - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } + - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } # "oldest" versions tests, only on minimum Python # - {os: "macOS-11", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"} # - {os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"} diff --git a/requirements/data/data.txt b/requirements/data/data.txt index 4fa81bd1a5..4813af9523 100644 --- a/requirements/data/data.txt +++ b/requirements/data/data.txt @@ -3,6 +3,6 @@ lightning-utilities >=0.8.0, <0.10.0 # to be able to include also 0.6 and preserve `>` needed for CI min version bypass -torchdata >0.5.9, <0.7.0 +torchdata >0.5.9, <=0.7.0 # to be able to include also PL 2.0 and preserve `>` needed for CI min version bypass -torch >0.14.0, <2.1.0 +torch >0.14.0, <=2.1.0 diff --git a/src/lightning/data/cache/__init__.py b/src/lightning/data/cache/__init__.py new file mode 100644 index 0000000000..1f9601debc --- /dev/null +++ b/src/lightning/data/cache/__init__.py @@ -0,0 +1,17 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lightning.data.cache.cache import Cache +from lightning.data.cache.dataloader import LightningDataLoader + +__all__ = ["Cache", "LightningDataLoader"] diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py new file mode 100644 index 0000000000..1d9c2d6e69 --- /dev/null +++ b/src/lightning/data/cache/cache.py @@ -0,0 +1,90 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE +from lightning.data.cache.reader import BinaryReader +from lightning.data.cache.sampler import ChunkedIndex +from lightning.data.cache.writer import BinaryWriter +from lightning.data.datasets.env import _DistributedEnv + +logger = logging.Logger(__name__) + + +class Cache: + def __init__( + self, + cache_dir: str, + remote_dir: Optional[str] = None, + compression: Optional[str] = None, + chunk_size: Optional[int] = None, + chunk_bytes: Optional[int] = None, + ): + """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements + together in order to accelerate fetching. + + Arguments: + cache_dir: The path to where the chunks will be stored. + remote_dir: The path to a remote folder where the data are located. + The scheme needs to be added to the path. + compression: The name of the algorithm to reduce the size of the chunks. + chunk_bytes: The maximum number of bytes within a chunk. + chunk_size: The maximum number of items within a chunk. + + """ + super().__init__() + if not _TORCH_2_1_0_AVAILABLE: + raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.") + self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression) + self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression) + self._cache_dir = cache_dir + self._is_done = False + self._distributed_env = _DistributedEnv.detect() + + @property + def filled(self) -> bool: + """Returns whether the caching phase is done.""" + if self._is_done: + return True + self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME)) + return self._is_done + + def __setitem__(self, index: int, data: Any) -> None: + """Store an item in the writer.""" + self._writer[index] = data + + def __getitem__(self, index: Union[int, ChunkedIndex]) -> Dict[str, Any]: + """Read an item in the reader.""" + if isinstance(index, int): + index = ChunkedIndex(index, self._get_chunk_index_from_index(index)) + return self._reader.read(index) + + def done(self) -> None: + """Inform the writer the chunking phase is finished.""" + self._writer.done() + + def merge(self, num_workers: int = 1) -> None: + """Inform the writer the chunking phase is finished.""" + self._writer.merge(num_workers) + + def __len__(self) -> int: + return self._reader.get_length() + + def get_chunk_interval(self) -> List[Tuple[int, int]]: + return self._reader.get_chunk_interval() + + def _get_chunk_index_from_index(self, index: int) -> int: + return self._reader._get_chunk_index_from_index(index) diff --git a/src/lightning/data/cache/compression.py b/src/lightning/data/cache/compression.py new file mode 100644 index 0000000000..68fbc2eaf3 --- /dev/null +++ b/src/lightning/data/cache/compression.py @@ -0,0 +1,76 @@ +# Copyright The Lightning AI team. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractclassmethod, abstractmethod +from typing import Dict, TypeVar + +from lightning_utilities.core.imports import RequirementCache, requires + +_ZSTD_AVAILABLE = RequirementCache("zstd") + +if _ZSTD_AVAILABLE: + import zstd + +TCompressor = TypeVar("TCompressor", bound="Compressor") + + +class Compressor(ABC): + """Base class for compression algorithm.""" + + @abstractmethod + def compress(self, data: bytes) -> bytes: + pass + + @abstractmethod + def decompress(self, data: bytes) -> bytes: + pass + + @abstractclassmethod + def register(cls, compressors: Dict[str, "Compressor"]) -> None: + pass + + +class ZSTDCompressor(Compressor): + """Compressor for the zstd package.""" + + @requires("zstd") + def __init__(self, level: int) -> None: + super().__init__() + self.level = level + self.extension = "zstd" + + @property + def name(self) -> str: + return f"{self.extension}:{self.level}" + + def compress(self, data: bytes) -> bytes: + return zstd.compress(data, self.level) + + def decompress(self, data: bytes) -> bytes: + return zstd.decompress(data) + + @classmethod + def register(cls, compressors: Dict[str, "Compressor"]) -> None: # type: ignore + if not _ZSTD_AVAILABLE: + return + + # default + compressors["zstd"] = ZSTDCompressor(4) + + for level in list(range(1, 23)): + compressors[f"zstd:{level}"] = ZSTDCompressor(level) + + +_COMPRESSORS: Dict[str, Compressor] = {} + +ZSTDCompressor.register(_COMPRESSORS) diff --git a/src/lightning/data/cache/config.py b/src/lightning/data/cache/config.py new file mode 100644 index 0000000000..2e23a15dc2 --- /dev/null +++ b/src/lightning/data/cache/config.py @@ -0,0 +1,125 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from typing import Any, Dict, List, Optional, Tuple + +from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE +from lightning.data.cache.downloader import get_downloader_cls +from lightning.data.cache.sampler import ChunkedIndex + +if _TORCH_2_1_0_AVAILABLE: + from torch.utils._pytree import treespec_loads + + +class ChunksConfig: + def __init__(self, cache_dir: str, remote_dir: Optional[str]): + """The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its + chunk. + + Arguments: + cache_dir: The path to cache folder. + remote_dir: The path to a remote folder where the data are located. + The scheme needs to be added to the path. 
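+
+        Example (editor's sketch, not part of the original change; assumes ``cache_dir`` already
+        contains the chunks and the ``index.json`` produced by the ``BinaryWriter``)::
+
+            config = ChunksConfig.load("/cache/my_dataset")
+            if config is not None:
+                chunk_index = config._get_chunk_index_from_index(42)
+                chunk_filepath, begin, end = config[ChunkedIndex(42, chunk_index)]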
+ + """ + self._cache_dir = cache_dir + self._intervals: List[Tuple[int, int]] = [] + self._config = None + self._chunks = [] + self._remote_dir = remote_dir + + with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f: + data = json.load(f) + + self._config = data["config"] + + self._chunks.extend(data["chunks"]) + + self._config["data_spec"] = treespec_loads(self._config["data_spec"]) + + for chunk in self._chunks: + start, end = chunk["interval"] + if (end - start) != chunk["chunk_size"]: + raise Exception( + "The config intervals doesn't match the number of samples. This shouldn't have happened." + ) + self._intervals.append((chunk["interval"][0], chunk["interval"][1])) + + self._length = sum([chunk["chunk_size"] for chunk in self._chunks]) + + self._downloader = None + + if remote_dir: + self._downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, self._chunks) + + def download_chunk_from_index(self, chunk_index: int) -> None: + chunk_filename = self._chunks[chunk_index]["filename"] + + local_chunkpath = os.path.join(self._cache_dir, chunk_filename) + + if os.path.exists(local_chunkpath): + return + + if self._downloader is None: + raise RuntimeError("The downloader should be defined.") + + self._downloader.download_chunk_from_index(chunk_index) + + @property + def intervals(self) -> List[Tuple[int, int]]: + if self._intervals is None: + raise RuntimeError("The intervals should be defined.") + return self._intervals + + @property + def data_format(self) -> Any: + if self._config is None: + raise RuntimeError("The config should be defined.") + return self._config["data_format"] + + @property + def config(self) -> Dict[str, Any]: + if self._config is None: + raise RuntimeError("The config should be defined.") + return self._config + + def _get_chunk_index_from_index(self, index: int) -> int: + for chunk_index, internal in enumerate(self._intervals): + if internal[0] <= index < internal[1]: + return chunk_index + raise ValueError( + f"The provided index {index} didn't find a match within the chunk intervals {self._intervals}." + ) + + def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]: + """Find the associated chunk metadata.""" + chunk = self._chunks[index.chunk_index] + return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index] + + @classmethod + def load(cls, cache_dir: str, remote_dir: Optional[str] = None) -> Optional["ChunksConfig"]: + cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME) + + if isinstance(remote_dir, str): + downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, []) + downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath) + + if not os.path.exists(cache_index_filepath): + return None + + return ChunksConfig(cache_dir, remote_dir) + + def __len__(self) -> int: + return self._length diff --git a/src/lightning/data/cache/constants.py b/src/lightning/data/cache/constants.py new file mode 100644 index 0000000000..d9dfa136c4 --- /dev/null +++ b/src/lightning/data/cache/constants.py @@ -0,0 +1,21 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lightning_utilities.core.imports import RequirementCache + +_INDEX_FILENAME = "index.json" +_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B + +# This is required for full pytree serialization / deserialization support +_TORCH_2_1_0_AVAILABLE = RequirementCache("torch>=2.1.0") +_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer") diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py new file mode 100644 index 0000000000..8f58b54c97 --- /dev/null +++ b/src/lightning/data/cache/dataloader.py @@ -0,0 +1,311 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import inspect +import logging +import os +from importlib import reload +from typing import Any, Callable, List, Optional + +import torch +from torch.utils.data import Dataset, IterableDataset +from torch.utils.data._utils.collate import default_collate +from torch.utils.data._utils.fetch import _BaseDatasetFetcher +from torch.utils.data.dataloader import ( + DataLoader, + _BaseDataLoaderIter, + _DatasetKind, + _MultiProcessingDataLoaderIter, + _SingleProcessDataLoaderIter, +) +from torch.utils.data.sampler import BatchSampler, Sampler + +from lightning.data.cache import Cache +from lightning.data.cache.constants import _DEFAULT_CHUNK_BYTES, _TORCH_2_1_0_AVAILABLE, _VIZ_TRACKER_AVAILABLE +from lightning.data.cache.sampler import CacheBatchSampler +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv + +if _TORCH_2_1_0_AVAILABLE: + from torch.utils._pytree import tree_flatten + +logger = logging.Logger(__name__) + + +def _equal_items(data_1: Any, data_2: Any) -> bool: + data_1_flattened, _ = tree_flatten(data_1) + data_2_flattened, _ = tree_flatten(data_2) + + if len(data_1_flattened) != len(data_2_flattened): + return False + + return all(_equal_item(d1, d2) for d1, d2 in zip(data_1_flattened, data_2_flattened)) + + +def _equal_item(d1: Any, d2: Any) -> bool: + if not isinstance(d1, type(d2)): + return False + equality = d1 == d2 + if isinstance(equality, torch.Tensor): + return bool(equality.all().item()) + if equality is True: + return True + return False + + +class CacheDataset(Dataset): + def __init__( + self, + dataset: Any, + cache_dir: str, + chunk_bytes: Optional[int], + chunk_size: Optional[int], + compression: Optional[str], + ): + """The `CacheDataset` is a dataset wraper to provide a beginner experience with the Cache. + + Arguments: + dataset: The dataset of the user + cache_dir: The folder where the chunks are written to. + chunk_bytes: The maximal number of bytes to write within a chunk. 
+            chunk_size: The maximal number of items to write to a chunk.
+            compression: The compression algorithm to use to reduce the size of the chunk.
+
+        """
+        self._dataset = dataset
+        self._cache = Cache(cache_dir, chunk_bytes=chunk_bytes, chunk_size=chunk_size, compression=compression)
+        self._is_deterministic = False
+
+    def __len__(self) -> int:
+        return len(self._cache) if self._cache.filled else len(self._dataset)
+
+    def __getitem__(self, index: int) -> Any:
+        data_1 = self._cache[index] if self._cache.filled else self._dataset[index]
+        if not self._cache.filled:
+            if not self._is_deterministic:
+                data2 = self._dataset[index]
+                if not _equal_items(data_1, data2):
+                    raise ValueError(
+                        f"Your dataset items aren't deterministic. Found {data_1} and {data2} for index {index}."
+                        " HINT: Use the `lightning.data.cache.Cache` directly within your dataset."
+                    )
+                self._is_deterministic = True
+            self._cache[index] = data_1
+        return data_1
+
+
+class CacheCollateFn:
+    """This CacheCollateFn is used to accelerate the processing of the data generated using the Cache.
+
+    During the chunking phase, there is no need to return any data from the DataLoader, which saves time.
+
+    Additionally, if the user makes their __getitem__ asynchronous, the collate executes the items in parallel.
+
+    """
+
+    def __init__(self, collate_fn: Optional[Callable] = None) -> None:
+        self.collate_fn = collate_fn or default_collate
+
+    def __call__(self, items: List[Any]) -> Any:
+        if all(item is None for item in items):
+            return None
+
+        # If the __getitem__ method is asynchronous, collect all the items.
+        if all(inspect.iscoroutine(item) for item in items):
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            items = loop.run_until_complete(asyncio.gather(*items))
+
+        return self.collate_fn([item for item in items if item is not None])
+
+
+class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter):
+    """This is overridden to inform the cache that the chunking phase is done."""
+
+    def _next_data(self) -> Any:
+        try:
+            data = None
+            while data is None:
+                data = super()._next_data()
+            return data
+        except StopIteration:
+            for v in self._dataset_fetcher.dataset.__dict__.values():
+                if isinstance(v, Cache):
+                    v.done()
+                    if not v.filled:
+                        v.merge(1)
+            raise StopIteration()
+
+
+class WorkerLoop:
+    """Wrap the PyTorch DataLoader WorkerLoop to perform caching and profiling."""
+
+    def __init__(self, global_rank: int, profile: bool = False) -> None:
+        self._global_rank = global_rank
+        self._profile = profile
+
+    def __call__(self, dataset_kind: _DatasetKind, *args: Any, **kwargs: Any) -> None:
+        from torch.utils.data._utils import worker
+
+        from lightning.data.cache.cache import Cache
+
+        rank = _WorkerEnv.detect().rank
+        enable_profiling = self._global_rank == 0 and rank == 0 and _VIZ_TRACKER_AVAILABLE and self._profile
+
+        if enable_profiling:
+            from viztracer import VizTracer
+
+            tracer = VizTracer(output_file=os.path.join(os.getcwd(), "trace.json"))
+            tracer.start()
+
+        # Reload to remove the patching
+        reloaded_worker = reload(worker)
+        create_fetcher = _DatasetKind.create_fetcher
+        fetcher = None
+
+        def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
+            nonlocal fetcher
+            fetcher = create_fetcher(*args, **kwargs)
+            return fetcher
+
+        _DatasetKind.create_fetcher = create_fetcher_fn  # type: ignore
+
+        reloaded_worker._worker_loop(dataset_kind, *args, **kwargs)
+
+        if dataset_kind == _DatasetKind.Map:
+            assert fetcher
+            for v in fetcher.dataset.__dict__.values():
+                if isinstance(v, Cache):
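+                    # Flush the last (possibly partial) chunk and write this worker's index file.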
+                    v.done()
+
+        if enable_profiling:
+            tracer.stop()
+            tracer.save()
+
+
+class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter):
+    def __init__(self, loader: DataLoader) -> None:
+        self._cache = loader._cache
+        self._num_workers = loader.num_workers
+        # Patch PyTorch worker loop to call the `cache.done()` method.
+        from torch.utils.data._utils import worker
+
+        worker._worker_loop = WorkerLoop(loader._global_rank, loader._profile)
+        super().__init__(loader)
+
+    def _shutdown_workers(self) -> None:
+        super()._shutdown_workers()
+
+        # If the cache isn't filled, we trigger an index merge of the per-worker index files.
+        if not self._cache.filled:
+            self._cache.merge(self._num_workers)
+
+    def _next_data(self) -> Any:
+        try:
+            data = None
+            while data is None:
+                data = super()._next_data()
+            return data
+        except StopIteration as e:
+            raise e
+
+
+class LightningDataLoader(DataLoader):
+    __doc__ = DataLoader.__doc__
+
+    def __init__(
+        self,
+        dataset: Any,
+        *args: Any,
+        sampler: Optional[Sampler] = None,
+        batch_sampler: Optional[BatchSampler] = None,
+        num_workers: int = 0,
+        shuffle: bool = False,
+        generator: Optional[torch.Generator] = None,
+        batch_size: Optional[int] = None,
+        drop_last: bool = False,
+        cache_dir: Optional[str] = None,
+        chunk_bytes: Optional[int] = _DEFAULT_CHUNK_BYTES,
+        compression: Optional[str] = None,
+        profile: bool = False,
+        collate_fn: Optional[Callable] = None,
+        **kwargs: Any,
+    ) -> None:
+        if sampler:
+            raise ValueError(
+                "The LightningDataLoader relies on its own internal sampler. Passing a sampler isn't supported."
+            )
+
+        if batch_sampler:
+            raise ValueError(
+                "The LightningDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported."
+            )
+
+        if isinstance(dataset, IterableDataset):
+            raise ValueError("Only map-based datasets are supported by the LightningDataLoader for now.")
+
+        if profile and not _VIZ_TRACKER_AVAILABLE:
+            raise ModuleNotFoundError("To enable DataLoader profiling, run `pip install viztracer`.")
+
+        cache_list = [v for v in dataset.__dict__.values() if isinstance(v, Cache)]
+
+        if len(cache_list) > 1:
+            raise ValueError(
+                "We found several Cache used as attributes in your dataset. Only one is supported for now."
+            )
+
+        if len(cache_list) == 0:
+            if cache_dir is None:
+                raise ValueError("You should provide a `cache_dir` filepath to the LightningDataLoader.")
+
+            dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size, compression)
+            cache = dataset._cache
+        else:
+            cache = cache_list[0]
+
+        if not cache.filled and shuffle:
+            logger.info("Shuffle is ignored during the caching phase.")
+
+        self._cache = cache
+
+        distributed_env = _DistributedEnv.detect()
+        self._global_rank = distributed_env.global_rank
+
+        batch_sampler = CacheBatchSampler(
+            len(dataset),
+            distributed_env.world_size,
+            self._global_rank,
+            num_workers,
+            batch_size or 1,
+            drop_last,
+            shuffle,
+            cache,
+        )
+
+        self._profile = profile
+
+        super().__init__(
+            dataset,
+            *args,
+            batch_sampler=batch_sampler,  # type: ignore
+            collate_fn=CacheCollateFn(collate_fn),
+            num_workers=num_workers,
+            **kwargs,
+        )
+
+    def _get_iterator(self) -> "_BaseDataLoaderIter":
+        """Overridden to ensure the `Cache.done()` method is triggered when iteration ends."""
+        if self.num_workers == 0:
+            return _SingleProcessDataLoaderIterPatch(self)
+        self.check_worker_number_rationality()
+        return _MultiProcessingDataLoaderIterPatch(self)
diff --git a/src/lightning/data/cache/downloader.py b/src/lightning/data/cache/downloader.py
new file mode 100644
index 0000000000..460d0e576d
--- /dev/null
+++ b/src/lightning/data/cache/downloader.py
@@ -0,0 +1,74 @@
+# Copyright The Lightning AI team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Type
+from urllib import parse
+
+
+class Downloader(ABC):
+    def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
+        self._remote_dir = remote_dir
+        self._cache_dir = cache_dir
+        self._chunks = chunks
+
+    def download_chunk_from_index(self, chunk_index: int) -> None:
+        chunk_filename = self._chunks[chunk_index]["filename"]
+        local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
+        remote_chunkpath = os.path.join(self._remote_dir, chunk_filename)
+        self.download_file(remote_chunkpath, local_chunkpath)
+
+    @abstractmethod
+    def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
+        pass
+
+
+class S3Downloader(Downloader):
+    @classmethod
+    def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
+        import boto3
+        from boto3.s3.transfer import TransferConfig
+        from botocore.config import Config
+
+        obj = parse.urlparse(remote_filepath)
+
+        if obj.scheme != "s3":
+            raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")
+
+        extra_args: Dict[str, Any] = {}
+
+        # Create a new session per thread
+        session = boto3.session.Session()
+        # Create a resource client using a thread's session object
+        s3 = session.client("s3", config=Config(read_timeout=None))
+        # Threads calling S3 operations return RuntimeError (cannot schedule new futures after
+        # interpreter shutdown).
Temporary solution is to have `use_threads` as `False`. + # Issue: https://github.com/boto/boto3/issues/3113 + s3.download_file( + obj.netloc, + obj.path.lstrip("/"), + local_filepath, + ExtraArgs=extra_args, + Config=TransferConfig(use_threads=False), + ) + + +# TODO: Add fsspec support +_DOWNLOADERS = {"s3://": S3Downloader} + + +def get_downloader_cls(remote_dir: str) -> Type[Downloader]: + for k, cls in _DOWNLOADERS.items(): + if remote_dir.startswith(k): + return cls + raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.") diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py new file mode 100644 index 0000000000..667360190b --- /dev/null +++ b/src/lightning/data/cache/reader.py @@ -0,0 +1,177 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from threading import Lock, Thread +from time import sleep +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from lightning.data.cache.config import ChunksConfig +from lightning.data.cache.constants import _TORCH_2_1_0_AVAILABLE +from lightning.data.cache.sampler import ChunkedIndex +from lightning.data.cache.serializers import _SERIALIZERS, Serializer +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv + +if _TORCH_2_1_0_AVAILABLE: + from torch.utils._pytree import PyTree, tree_unflatten + + +class PrepareChunksThread(Thread): + """This thread is responsible to download the chunks associated to a given worker.""" + + def __init__(self, config: ChunksConfig) -> None: + super().__init__(daemon=True) + self._config = config + self._chunks_index_to_be_processed: List[int] = [] + self._chunks_index_to_ready: List[int] = [] + self._lock = Lock() + + def add(self, chunk_indices: List[int]) -> None: + """Receive the list of the chunk indices to download for the current epoch.""" + with self._lock: + self._chunks_index_to_be_processed.extend(chunk_indices) + + def run(self) -> None: + while True: + with self._lock: + if len(self._chunks_index_to_be_processed) == 0: + sleep(0.007) + continue + + chunk_index = self._chunks_index_to_be_processed.pop(0) + + # TODO: Implement eviction + self._config.download_chunk_from_index(chunk_index) + self._chunks_index_to_ready.append(chunk_index) + + +class BinaryReader: + def __init__(self, cache_dir: str, remote_dir: Optional[str] = None, compression: Optional[str] = None) -> None: + """The BinaryReader enables to read chunked dataset in an efficient way. + + Arguments: + cache_dir: The path to cache folder. + remote_dir: The path to a remote folder where the data are located. + The scheme needs to be added to the path. + compression: The algorithm to decompress the chunks. 
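+
+        Example (editor's sketch, not part of the original change; assumes ``/cache/my_dataset``
+        already holds a written cache)::
+
+            reader = BinaryReader("/cache/my_dataset")
+            index = ChunkedIndex(0, reader._get_chunk_index_from_index(0))
+            item = reader.read(index)  # deserialized back into its original python structure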
+ + """ + super().__init__() + self._cache_dir = cache_dir + self._remote_dir = remote_dir + + if not os.path.exists(self._cache_dir): + raise FileNotFoundError(f"The provided cache_dir `{self._cache_dir}` doesn't exist.") + + self._compression = compression + self._intervals: Optional[List[str]] = None + + self._serializers: Dict[str, Serializer] = _SERIALIZERS + self._distributed_env = _DistributedEnv.detect() + self._rank: Optional[int] = None + self._config: Optional[ChunksConfig] = None + self._prepare_thread: Optional[PrepareChunksThread] = None + + def _get_chunk_index_from_index(self, index: int) -> int: + # Load the config containing the index + if self._config is None and self._try_load_config() is None: + raise Exception("The reader index isn't defined.") + + return self._config._get_chunk_index_from_index(index) # type: ignore + + def _try_load_config(self) -> Optional[ChunksConfig]: + """Try to load the chunks config if the index files are available.""" + self._config = ChunksConfig.load(self._cache_dir, self._remote_dir) + return self._config + + @property + def config(self) -> ChunksConfig: + if self._config is None: + raise RuntimeError("The config should be defined.") + return self._config + + @property + def rank(self) -> int: + """Returns the rank of the writer.""" + if self._rank is None: + self._worker_env = _WorkerEnv.detect() + self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank + return self._rank + + def read(self, index: ChunkedIndex) -> Any: + """Read an item for the given from a chunk. + + If the chunk isn't available locally or in memory, it will be downloaded. + + Prefetching should reduce the wait time to be the batch available. + + """ + if not isinstance(index, ChunkedIndex): + raise ValueError("The Reader.read(...) 
method expects a chunked Index.") + + # Load the config containing the index + if self._config is None and self._try_load_config() is None: + raise Exception("The reader index isn't defined.") + + # Create and start the prepare chunks thread + if index.chunk_indexes is not None and self._prepare_thread is None and self._config: + self._prepare_thread = PrepareChunksThread(self._config) + self._prepare_thread.start() + self._prepare_thread.add(index.chunk_indexes) + + # Fetch the element + chunk_filepath, begin, _ = self.config[index] + raw_item_data = self.load_item_from_chunk(index.index, chunk_filepath, begin) + return self.deserialize(raw_item_data) + + def deserialize(self, raw_item_data: bytes) -> "PyTree": + """Deserialize the raw bytes into their python equivalent.""" + idx = len(self.config.data_format) * 4 + sizes = np.frombuffer(raw_item_data[:idx], np.uint32) + data = [] + for size, data_format in zip(sizes, self.config.data_format): + serializer = self._serializers[data_format] + data_bytes = raw_item_data[idx : idx + size] + data.append(serializer.deserialize(data_bytes)) + idx += size + return tree_unflatten(data, self.config.config["data_spec"]) + + def load_item_from_chunk(self, index: int, chunk_filepath: str, begin: int) -> bytes: + offset = (1 + (index - begin)) * 4 + + while not os.path.exists(chunk_filepath): + sleep(0.0001) + + with open(chunk_filepath, "rb", 0) as fp: + fp.seek(offset) + pair = fp.read(8) + begin, end = np.frombuffer(pair, np.uint32) + fp.seek(begin) + data = fp.read(end - begin) + return data + + def get_length(self) -> int: + """Get the number of samples across all chunks.""" + if self._config is None and self._try_load_config() is None: + raise Exception("The reader index isn't defined.") + + return len(self.config) + + def get_chunk_interval(self) -> List[Tuple[int, int]]: + """Get the index interval of each chunk.""" + if self._config is None and self._try_load_config() is None: + raise Exception("The reader index isn't defined.") + + return self.config.intervals diff --git a/src/lightning/data/cache/sampler.py b/src/lightning/data/cache/sampler.py new file mode 100644 index 0000000000..fe88a2cf2c --- /dev/null +++ b/src/lightning/data/cache/sampler.py @@ -0,0 +1,226 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Union + +import numpy as np + +logger = logging.Logger(__name__) + + +@dataclass +class ChunkedIndex: + index: int + chunk_index: int + chunk_indexes: Optional[List[int]] = None + + +class CacheBatchSampler: + def __init__( + self, + dataset_size: int, + num_replicas: int, + global_rank: int, + num_workers: int, + batch_size: int, + drop_last: bool, + shuffle: bool, + cache: Any, + ): + """The CacheBatchSampler handles the generation of batch indices. 
+ + If the cache isn't filled, the batch sampler alternates with ordered indices for the writer to chunk the dataset + If the cache is filled, it acts as normal BatchSampler. + + Arguments: + dataset_size: The size of the dataset. + num_replicas: The number of processes involves in the distributed training. + global_rank: The global_rank of the given process + num_workers: The number of workers provided to the DataLoader. + batch_size: The number of items in a batch. + drop_last: Whether to drop the last batch of data. + shuffle: Whether the data should be shuffled. + cache: The cache associated to the dataset. + + """ + self._dataset_size = dataset_size + self._num_replicas = num_replicas + self._global_rank = global_rank + self._cache = cache + self._shuffle = shuffle + self._num_workers = num_workers or 1 + self._shuffled_chunk_intervals = None + self._batch_size = batch_size + + self._drop_last = drop_last + self._length = 0 + + # Before starting, ensures the chunk indices are properly defined. + self._validate() + + def _validate(self) -> None: + """Checks each worker is getting sucessive indices.""" + if self._num_workers > 1 and not self._cache.filled: + batches: Dict[int, Any] = {} + for batch_index, batch_indices in enumerate(self): + self._length += 1 + worker_index = batch_index % self._num_workers + if worker_index not in batches: + batches[worker_index] = [] + batches[worker_index].extend(batch_indices) + elif len(batch_indices) > 0: + batches[worker_index].extend(batch_indices) + + for indices in batches.values(): + indices = np.asarray(indices) + diff = indices[1:] - (indices[:-1] + 1) + if diff.sum() != 0: + raise RuntimeError("This shouldn't have happened. There is a bug in the CacheSampler.") + + def __iter__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: + # When the cache is filled, we need to iterate though the chunks + if self._cache.filled: + if self._num_replicas == 1: + return self.__iter_from_chunks_non_distributed__() + return self.__iter_from_chunks_distributed__() + + # shuffle is ignored while building the binarized version of the dataset + if self._num_replicas == 1: + return self.__iter_non_distributed__() + return self.__iter_distributed__() + + def __iter_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: + worker_size = self._dataset_size // self._num_workers + indices = list(range(self._dataset_size)) + worker_indices = [] + for worker_idx in range(self._num_workers): + is_last = worker_idx == self._num_workers - 1 + start = worker_idx * worker_size + end = self._dataset_size if is_last else (worker_idx + 1) * worker_size + worker_indices.append(indices[start:end]) + + assert sum([len(s) for s in worker_indices]) == self._dataset_size + + worker_indices_batches = [self._chunk_list(indices, self._batch_size) for indices in worker_indices] + + yield from self.__iter_indices_per_workers__(worker_indices_batches) + + def __iter_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: + self.indices = list(range(self._dataset_size)) + replica_size = self._dataset_size // self._num_replicas + worker_size = self._dataset_size // (self._num_replicas * self._num_workers) + for rank in range(self._num_replicas): + if rank != self._global_rank: + continue + + is_last_replica = rank == self._num_replicas - 1 + start_replica = rank * replica_size + end_replica = self._dataset_size if is_last_replica else (rank + 1) * replica_size + replica_indices = self.indices[start_replica:end_replica] + + replica_size = len(replica_indices) + + 
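+            # Split this replica's contiguous index range across the DataLoader workers,
+            # mirroring the non-distributed path above.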
worker_indices = [] + for worker_idx in range(self._num_workers): + is_last_worker = worker_idx == self._num_workers - 1 + start_worker = worker_idx * worker_size + end_worker = replica_size if is_last_worker else (worker_idx + 1) * worker_size + worker_indices.append(replica_indices[start_worker:end_worker]) + + assert sum([len(s) for s in worker_indices]) == len(replica_indices) + + worker_indices_batches = [self._chunk_list(indices, self._batch_size) for indices in worker_indices] + + yield from self.__iter_indices_per_workers__(worker_indices_batches) + + def __iter_from_chunks_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: + chunk_intervals = self._cache.get_chunk_interval() + shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) + shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] + yield from self.__iter_from_shuffled_chunks(shuffled_indexes.tolist(), shuffled_chunk_intervals) + + def __iter_from_chunks_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: + chunk_intervals = self._cache.get_chunk_interval() + shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) + shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] + + replica_chunks = [] + replica_intervals = [] + for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)): + if index % self._num_replicas == self._global_rank: + replica_chunks.append(chunk_index) + replica_intervals.append(chunk_interval) + + yield from self.__iter_from_shuffled_chunks(replica_chunks, replica_intervals) + + def __iter_from_shuffled_chunks( + self, shuffled_indexes: List[int], shuffled_chunk_intervals: List[List[int]] + ) -> Iterator[List[Union[int, ChunkedIndex]]]: + chunks_per_workers: List[List[int]] = [[] for _ in range(self._num_workers)] + for i, chunk_index in enumerate(shuffled_indexes): + chunks_per_workers[i % self._num_workers].append(chunk_index) + + indices_per_workers: List[List[ChunkedIndex]] = [[] for _ in range(self._num_workers)] + + for i, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)): + worker_id = i % self._num_workers + interval_indices = np.arange(chunk_interval[0], chunk_interval[1]) + shuffled_interval_indices = np.random.permutation(interval_indices).tolist() + is_empty = len(indices_per_workers[worker_id]) == 0 + indices_per_workers[worker_id].extend( + [ + ChunkedIndex( + index, + chunk_index, + chunk_indexes=chunks_per_workers[worker_id] if j == 0 and is_empty else None, + ) + for j, index in enumerate(shuffled_interval_indices) + ] + ) + + indices_per_workers_splitted = [self._chunk_list(indices, self._batch_size) for indices in indices_per_workers] + + yield from self.__iter_indices_per_workers__(indices_per_workers_splitted) + + def __len__(self) -> int: + return self._length + + def __iter_indices_per_workers__( + self, indices_per_workers: List[List[List[Union[int, ChunkedIndex]]]] + ) -> Iterator[List[Union[int, ChunkedIndex]]]: + batches: List[List[Union[int, ChunkedIndex]]] = [] + counter = 0 + while sum([len(v) for v in indices_per_workers]) != 0: + worker_indices = indices_per_workers[counter % self._num_workers] + if len(worker_indices) == 0: + batches.append([]) + else: + batches.append(worker_indices.pop(0)) + counter += 1 + + while True: + if len(batches[-1]) == 0: + batches.pop(-1) + else: + break + + yield from batches + + def _chunk_list(self, arr: List[Any], chunk_size: int) -> List[List[Any]]: + out = [] + 
for i in range(0, len(arr), chunk_size): + slice_item = slice(i, i + chunk_size, 1) + out.append(arr[slice_item]) + return out diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py new file mode 100644 index 0000000000..6ba5d5a4db --- /dev/null +++ b/src/lightning/data/cache/serializers.py @@ -0,0 +1,223 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +from abc import ABC, abstractmethod +from collections import OrderedDict +from io import BytesIO +from typing import Any, Optional, Tuple, Union + +import numpy as np +import torch +from lightning_utilities.core.imports import RequirementCache + +_PIL_AVAILABLE = RequirementCache("PIL") +_TORCH_VISION_AVAILABLE = RequirementCache("torchvision") + +if _PIL_AVAILABLE: + from PIL import Image + from PIL.JpegImagePlugin import JpegImageFile +else: + Image = None + JpegImageFile = None + +if _TORCH_VISION_AVAILABLE: + from torchvision.io import decode_jpeg + + +class Serializer(ABC): + """The base interface for any serializers. + + A Serializer serialize and deserialize to and from bytes. + + """ + + @abstractmethod + def serialize(self, data: Any) -> Tuple[bytes, Optional[str]]: + pass + + @abstractmethod + def deserialize(self, data: bytes) -> Any: + pass + + @abstractmethod + def can_serialize(self, data: Any) -> bool: + pass + + +class PILSerializer(Serializer): + """The PILSerializer serialize and deserialize PIL Image to and from bytes.""" + + def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]: + mode = item.mode.encode("utf-8") + width, height = item.size + raw = item.tobytes() + ints = np.array([width, height, len(mode)], np.uint32) + return ints.tobytes() + mode + raw, None + + def deserialize(self, data: bytes) -> Any: + idx = 3 * 4 + width, height, mode_size = np.frombuffer(data[:idx], np.uint32) + idx2 = idx + mode_size + mode = data[idx:idx2].decode("utf-8") + size = width, height + raw = data[idx2:] + return Image.frombytes(mode, size, raw) # pyright: ignore + + def can_serialize(self, item: Any) -> bool: + return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile) + + +class IntSerializer(Serializer): + """The IntSerializer serialize and deserialize integer to and from bytes.""" + + def serialize(self, item: int) -> Tuple[bytes, Optional[str]]: + return str(item).encode("utf-8"), None + + def deserialize(self, data: bytes) -> int: + return int(data.decode("utf-8")) + + def can_serialize(self, item: Any) -> bool: + return isinstance(item, int) + + +class JPEGSerializer(Serializer): + """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" + + def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]: + if isinstance(item, JpegImageFile): + if not hasattr(item, "filename"): + raise ValueError( + "The JPEG Image's filename isn't defined. HINT: Open the image in your Dataset __getitem__ method." 
+ ) + with open(item.filename, "rb") as f: + return f.read(), None + raise TypeError(f"The provided itemect should be of type {JpegImageFile}. Found {item}.") + + def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]: + if _TORCH_VISION_AVAILABLE: + array = torch.frombuffer(data, dtype=torch.uint8) + return decode_jpeg(array) + + inp = BytesIO(data) + return Image.open(inp) + + def can_serialize(self, item: Any) -> bool: + return isinstance(item, JpegImageFile) + + +class BytesSerializer(Serializer): + """The BytesSerializer serialize and deserialize integer to and from bytes.""" + + def serialize(self, item: bytes) -> Tuple[bytes, Optional[str]]: + return item, None + + def deserialize(self, item: bytes) -> bytes: + return item + + def can_serialize(self, item: bytes) -> bool: + return isinstance(item, bytes) + + +_TORCH_DTYPES_MAPPING = { + 0: torch.float32, + 1: torch.float, + 2: torch.float64, + 3: torch.double, + 4: torch.complex64, + 5: torch.cfloat, + 6: torch.complex128, + 7: torch.cdouble, + 8: torch.float16, + 9: torch.half, + 10: torch.bfloat16, # Not supported https://github.com/pytorch/pytorch/issues/110285 + 11: torch.uint8, + 12: torch.int8, + 13: torch.int16, + 14: torch.short, + 15: torch.int32, + 16: torch.int, + 17: torch.int64, + 18: torch.long, + 19: torch.bool, +} + + +class TensorSerializer(Serializer): + """The TensorSerializer serialize and deserialize tensor to and from bytes.""" + + def __init__(self) -> None: + super().__init__() + self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()} + + def serialize(self, item: torch.Tensor) -> Tuple[bytes, Optional[str]]: + dtype_indice = self._dtype_to_indice[item.dtype] + data = [np.uint32(dtype_indice).tobytes()] + data.append(np.uint32(len(item.shape)).tobytes()) + for dim in item.shape: + data.append(np.uint32(dim).tobytes()) + data.append(item.numpy().tobytes()) + return b"".join(data), None + + def deserialize(self, data: bytes) -> torch.Tensor: + dtype_indice = np.frombuffer(data[0:4], np.uint32).item() + dtype = _TORCH_DTYPES_MAPPING[dtype_indice] + shape_size = np.frombuffer(data[4:8], np.uint32).item() + shape = [] + for shape_idx in range(shape_size): + shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) + tensor = torch.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype) + return torch.reshape(tensor, torch.Size(shape)) + + def can_serialize(self, item: torch.Tensor) -> bool: + return isinstance(item, torch.Tensor) and type(item) == torch.Tensor + + +class PickleSerializer(Serializer): + """The PickleSerializer serialize and deserialize python objects to and from bytes.""" + + def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: + return pickle.dumps(item), None + + def deserialize(self, data: bytes) -> Any: + return pickle.loads(data) + + def can_serialize(self, _: Any) -> bool: + return True + + +class FileSerializer(Serializer): + def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]: + _, file_extension = os.path.splitext(filepath) + with open(filepath, "rb") as f: + return f.read(), file_extension.replace(".", "").lower() + + def deserialize(self, data: bytes) -> Any: + pass + + def can_serialize(self, data: Any) -> bool: + return isinstance(data, str) and os.path.exists(data) + + +_SERIALIZERS = OrderedDict( + **{ + "file": FileSerializer(), + "pil": PILSerializer(), + "int": IntSerializer(), + "jpeg": JPEGSerializer(), + "bytes": BytesSerializer(), + "tensor": TensorSerializer(), + 
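+        # Note: ordering matters. The writer picks the first serializer whose `can_serialize`
+        # returns True, so the catch-all pickle serializer must stay last.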
"pickle": PickleSerializer(), + } +) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py new file mode 100644 index 0000000000..29981a89f4 --- /dev/null +++ b/src/lightning/data/cache/writer.py @@ -0,0 +1,305 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from time import sleep +from typing import Any, Dict, List, Optional + +import numpy as np + +from lightning.data.cache.compression import _COMPRESSORS, Compressor +from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE +from lightning.data.cache.serializers import _SERIALIZERS, Serializer +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv + +if _TORCH_2_1_0_AVAILABLE: + from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps + + +class BinaryWriter: + def __init__( + self, + cache_dir: str, + chunk_size: Optional[int] = None, + chunk_bytes: Optional[int] = None, + compression: Optional[str] = None, + ): + """The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training. + + Arguments: + cache_dir: The path to where the chunks will be saved. + chunk_bytes: The maximum number of bytes within a chunk. + chunk_size: The maximum number of items within a chunk. + compression: The compression algorithm to use. 
+ + """ + self._cache_dir = cache_dir + + if not os.path.exists(self._cache_dir): + raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.") + + if (chunk_size is None and chunk_bytes is None) or (chunk_size and chunk_bytes): + raise ValueError("Either one of the `chunk_size` or the `chunk_bytes` need to be provided.") + + self._serializers: Dict[str, Serializer] = _SERIALIZERS + self._chunk_size = chunk_size + self._chunk_bytes = chunk_bytes + self._compression = compression + + self._data_format: Optional[List[str]] = None + self._data_spec: Optional[PyTree] = None + + if self._compression: + if len(_COMPRESSORS) == 0: + raise ValueError("No compresion algorithms are installed.") + if self._compression not in _COMPRESSORS: + raise ValueError( + f"The provided compression {self._compression} isn't available in {sorted(_COMPRESSORS)}" + ) + self._compressor: Compressor = _COMPRESSORS[self._compression] + + self._current_chunk_bytes = 0 + self._chunk_index = 0 + self._serialized_items: List[bytes] = [] + self._chunks_info: List[Dict[str, Any]] = [] + self._indexes: List[int] = [] + self._worker_env: Optional[_WorkerEnv] = None + self._rank: Optional[int] = None + self._is_done = False + self._distributed_env = _DistributedEnv.detect() + + @property + def filled(self) -> bool: + """Returns whether the caching phase is done.""" + if self._is_done: + return True + files = os.listdir(self._cache_dir) + index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] + worker_end = _WorkerEnv.detect() + self._is_done = len(index_files) == self._distributed_env.world_size * worker_end.world_size + return self._is_done + + @property + def rank(self) -> int: + """Returns the rank of the writer.""" + if self._rank is None: + self._worker_env = _WorkerEnv.detect() + self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank + return self._rank + + def get_config(self) -> Dict[str, Any]: + """Returns the config of the writer.""" + out = { + "compression": self._compression, + "chunk_size": self._chunk_size, + "chunk_bytes": self._chunk_bytes, + "data_format": self._data_format, + "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None, + } + return out + + def serialize(self, items: Any) -> bytes: + """Serialize a dictionary into its binary format.""" + + # Flatten the items provided by the users + flattened, data_spec = tree_flatten(items) + + # Collect the sizes and associated bytes for each item + sizes: List[int] = [] + data: List[bytes] = [] + + data_format: List[str] = [] + for item in flattened: + data_format.append(self._serialize(item, sizes, data)) + + if self._data_format is None: + self._data_format = data_format + elif self._data_format != data_format: + raise Exception( + f"The data format changed between items. Found {data_format} instead of {self._data_format}." + ) + + if self._data_spec is None: + self._data_spec = data_spec + elif self._data_spec != data_spec: + raise Exception(f"The data format changed between items. 
Found {data_spec} instead of {self._data_spec}.") + + # Concatenante into a single byte array + head = np.array(sizes, np.uint32).tobytes() + body = b"".join(data) + return head + body + + def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> str: + """Serialize a given item and append its size and bytes to the sizes and data array.""" + for serializer_name, serializer in self._serializers.items(): + if serializer.can_serialize(item): + serialized_item, name = serializer.serialize(item) + data.append(serialized_item) + sizes.append(len(serialized_item)) + return name or serializer_name + raise ValueError(f"The provided item isn't serializable. Found {item}") + + def _create_chunk(self, filename: str) -> bytes: + """Create a binary chunk from all the binarized items.""" + num_items = np.uint32(len(self._serialized_items)) + sizes = list(map(len, self._serialized_items)) + offsets = np.array([0] + sizes).cumsum().astype(np.uint32) + offsets += len(num_items.tobytes()) + len(offsets.tobytes()) + sample_data = b"".join(self._serialized_items) + data = num_items.tobytes() + offsets.tobytes() + sample_data + offsets = offsets.tolist() + mapping = {} + for i in range(len(self._indexes)): + mapping[self._indexes[i]] = [offsets[i], offsets[i + 1]] + + assert len(mapping) == len(self._indexes) + assert (self._indexes[-1] - self._indexes[0] + 1) == len(self._serialized_items) + + chunk_info = { + "chunk_bytes": self._current_chunk_bytes, + "chunk_size": len(self._serialized_items), + "filename": filename, + "interval": [self._indexes[0], self._indexes[-1] + 1], + } + + self._chunks_info.append(chunk_info) + + return data + + def write_chunk(self) -> None: + """Write a chunk to the filesystem.""" + if self._compression: + filename = f"chunk-{self.rank}-{self._chunk_index}.{self._compression}.bin" + else: + filename = f"chunk-{self.rank}-{self._chunk_index}.bin" + self.write_chunk_to_file(self._create_chunk(filename), filename) + + def reset(self) -> None: + """Reset the writer to handle the next chunk.""" + self._serialized_items = [] + self._indexes = [] + self._current_chunk_bytes = 0 + + def __setitem__(self, index: int, items: Any) -> None: + """Store an item to a chunk. + + The index needs to be provided in order. + + This is handled by the samplers automatically. This ensures we can map an index to a shard from an interval. + + """ + # Serialize the items + serialized_items = self.serialize(items) + serialized_items_size = len(serialized_items) + + # Check whether it is time to write a chunk + should_write = ( + self._chunk_bytes and self._chunk_bytes < self._current_chunk_bytes + serialized_items_size + ) or (self._chunk_size and len(self._indexes) >= self._chunk_size) + + if should_write: + if self._current_chunk_bytes == 0: + raise Exception( + f"The provided chunk_size {self._chunk_bytes} is too small." + f" You should use a multiple of {serialized_items_size} bytes." + ) + self.write_chunk() + self.reset() + self._chunk_index += 1 + + # Store the serialized items into the chunk. 
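+        # Each entry is the payload produced by `serialize`: a uint32 size header (one size per
+        # flattened leaf) followed by the serialized bytes.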
+ self._serialized_items.append(serialized_items) + self._current_chunk_bytes += serialized_items_size + + # Validate the index are provided in an incremental order + # This is required to ensure we can find efficiently a chunk index from an index using the chunk interval + if self._indexes: + assert self._indexes[-1] == index - 1, (self._indexes, index - 1) + + # Store the index + self._indexes.append(index) + + def write_chunk_to_file( + self, + raw_data: bytes, + filename: str, + ) -> None: + """Write chunk bytes to a file.""" + # Whether to compress the raw bytes + if self._compression: + raw_data = self._compressor.compress(raw_data) + + # Write the binary chunk file + with open(os.path.join(self._cache_dir, filename), "wb") as out: + out.write(raw_data) + + def write_chunks_index(self) -> None: + """Write the chunks index to a JSON file.""" + filepath = os.path.join(self._cache_dir, f"{self.rank}.{_INDEX_FILENAME}") + config = self.get_config() + with open(filepath, "w") as out: + json.dump({"chunks": self._chunks_info, "config": config}, out, sort_keys=True) + + def done(self) -> None: + """Called when StopIteration is triggered.""" + if self.filled: + return + if self._serialized_items: + self.write_chunk() + self.write_chunks_index() + self.reset() + self._is_done = True + + def merge(self, num_workers: int = 1) -> None: + """Once all the workers have written their own index, the merge function is responsible to read and merge them + into a single index.""" + num_workers = num_workers or 1 + + # Only for non rank 0 + if self.rank != 0: + while not os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME)): + sleep(0.001) + return + + # Wait for all indexes to be available + is_done = False + while not is_done: + files = os.listdir(self._cache_dir) + if _INDEX_FILENAME in files: + return + index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] + is_done = len(index_files) == self._distributed_env.world_size * num_workers + sleep(0.001) + + # Read the index and append the chunks together + chunks_info = [] + config = None + for index_filename in sorted(index_files): + chunk_path = os.path.join(self._cache_dir, index_filename) + with open(chunk_path) as f: + data = json.load(f) + + if config is None: + config = data["config"] + + elif config != data["config"]: + raise Exception("The config isn't consistent between chunks. This shouldn't have happened.") + + chunks_info.extend(data["chunks"]) + + os.remove(chunk_path) + + # Write down the collected index + with open(os.path.join(self._cache_dir, _INDEX_FILENAME), "w") as f: + json.dump({"chunks": chunks_info, "config": config}, f, sort_keys=True) diff --git a/src/lightning/data/datasets/env.py b/src/lightning/data/datasets/env.py index 51a9f21271..e369448ff8 100644 --- a/src/lightning/data/datasets/env.py +++ b/src/lightning/data/datasets/env.py @@ -1,7 +1,7 @@ -from typing import Optional +from typing import Callable, Optional import torch -from torch.utils.data import get_worker_info +from torch.utils.data import get_worker_info as torch_get_worker_info class _DistributedEnv: @@ -60,7 +60,7 @@ class _WorkerEnv: self.rank = rank @classmethod - def detect(cls) -> "_WorkerEnv": + def detect(cls, get_worker_info_fn: Optional[Callable] = None) -> "_WorkerEnv": """Automatically detects the number of workers and the current rank. 
Note: @@ -68,6 +68,7 @@ class _WorkerEnv: In such a case it will default to 1 worker """ + get_worker_info = get_worker_info_fn or torch_get_worker_info worker_info = get_worker_info() num_workers = worker_info.num_workers if worker_info is not None else 1 current_worker_rank = worker_info.id if worker_info is not None else 0 diff --git a/tests/tests_data/cache/__init__.py b/tests/tests_data/cache/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py new file mode 100644 index 0000000000..d1605123a0 --- /dev/null +++ b/tests/tests_data/cache/test_cache.py @@ -0,0 +1,205 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from functools import partial + +import numpy as np +import pytest +import torch +from lightning import seed_everything +from lightning.data.cache import Cache +from lightning.data.cache.dataloader import LightningDataLoader +from lightning.data.datasets.env import _DistributedEnv +from lightning.fabric import Fabric +from lightning.pytorch.demos.boring_classes import RandomDataset +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import Dataset + +_PIL_AVAILABLE = RequirementCache("PIL") +_TORCH_VISION_AVAILABLE = RequirementCache("torchvision") + + +class ImageDataset(Dataset): + def __init__(self, tmpdir, cache, size, num_classes): + from PIL import Image + + self.data = [] + self.cache = cache + + seed_everything(42) + + for i in range(size): + path = os.path.join(tmpdir, f"img{i}.jpeg") + np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8) + img = Image.fromarray(np_data).convert("L") + img.save(path, format="jpeg", quality=100) + self.data.append({"image": path, "class": np.random.randint(num_classes)}) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + if self.cache.filled: + return self.cache[index] + self.cache[index] = {**self.data[index], "index": index} + return None + + +def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): + from PIL import Image + from torchvision.transforms import PILToTensor + + dataset_size = 85 + + cache_dir = os.path.join(tmpdir, "cache") + distributed_env = _DistributedEnv.detect() + + cache = Cache(cache_dir, chunk_size=10) + dataset = ImageDataset(tmpdir, cache, dataset_size, 10) + dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4) + + for _ in dataloader: + pass + + # Not strictly required but added to avoid race condition + if distributed_env.world_size > 1: + fabric.barrier() + + assert cache.filled + + for i in range(len(dataset)): + cached_data = dataset[i] + original_data = dataset.data[i] + assert cached_data["class"] == original_data["class"] + original_array = PILToTensor()(Image.open(original_data["image"])) + assert torch.equal(original_array, cached_data["image"]) + + if distributed_env.world_size == 1: + indexes = [] + dataloader = 
LightningDataLoader(dataset, num_workers=num_workers, batch_size=4) + for batch in dataloader: + if batch: + indexes.extend(batch["index"].numpy().tolist()) + assert len(indexes) == dataset_size + + seed_everything(42) + + dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4, shuffle=True) + dataloader_iter = iter(dataloader) + + indexes = [] + for batch in dataloader_iter: + indexes.extend(batch["index"].numpy().tolist()) + + if distributed_env.world_size == 1: + assert len(indexes) == dataset_size + + indexes2 = [] + for batch in dataloader_iter: + indexes2.extend(batch["index"].numpy().tolist()) + + assert indexes2 != indexes + + +@pytest.mark.skipif( + condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']" +) +@pytest.mark.parametrize("num_workers", [0, 1, 2]) +def test_cache_for_image_dataset(num_workers, tmpdir): + cache_dir = os.path.join(tmpdir, "cache") + os.makedirs(cache_dir) + + _cache_for_image_dataset(num_workers, tmpdir) + + +def _fabric_cache_for_image_dataset(fabric, num_workers, tmpdir): + _cache_for_image_dataset(num_workers, tmpdir, fabric=fabric) + + +@pytest.mark.skipif( + condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE or sys.platform == "win32", + reason="Requires: ['pil', 'torchvision']", +) +@pytest.mark.parametrize("num_workers", [2]) +def test_cache_for_image_dataset_distributed(num_workers, tmpdir): + cache_dir = os.path.join(tmpdir, "cache") + os.makedirs(cache_dir) + + fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn") + fabric.launch(partial(_fabric_cache_for_image_dataset, num_workers=num_workers, tmpdir=tmpdir)) + + +def test_cache_with_simple_format(tmpdir): + cache_dir = os.path.join(tmpdir, "cache1") + os.makedirs(cache_dir) + + cache = Cache(cache_dir, chunk_bytes=90) + + for i in range(100): + cache[i] = i + + cache.done() + cache.merge() + + for i in range(100): + assert i == cache[i] + + cache_dir = os.path.join(tmpdir, "cache2") + os.makedirs(cache_dir) + + cache = Cache(cache_dir, chunk_bytes=90) + + for i in range(100): + cache[i] = [i, {0: [i + 1]}] + + cache.done() + cache.merge() + + for i in range(100): + assert [i, {0: [i + 1]}] == cache[i] + + +def test_cache_with_auto_wrapping(tmpdir): + os.makedirs(os.path.join(tmpdir, "cache_1"), exist_ok=True) + + dataset = RandomDataset(64, 64) + dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_1"), chunk_bytes=2 << 12) + for batch in dataloader: + assert isinstance(batch, torch.Tensor) + assert sorted(os.listdir(os.path.join(tmpdir, "cache_1"))) == [ + "chunk-0-0.bin", + "chunk-0-1.bin", + "chunk-0-2.bin", + "index.json", + ] + # Your dataset is optimised for the cloud + + class RandomDatasetAtRuntime(Dataset): + def __init__(self, size: int, length: int): + self.len = length + self.size = size + + def __getitem__(self, index: int) -> torch.Tensor: + return torch.randn(1, self.size) + + def __len__(self) -> int: + return self.len + + os.makedirs(os.path.join(tmpdir, "cache_2"), exist_ok=True) + dataset = RandomDatasetAtRuntime(64, 64) + dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_2"), chunk_bytes=2 << 12) + with pytest.raises(ValueError, match="Your dataset items aren't deterministic"): + for batch in dataloader: + pass diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py new file mode 100644 index 0000000000..d6528794a3 --- /dev/null +++ b/tests/tests_data/cache/test_sampler.py @@ -0,0 +1,112 
@@ +from unittest import mock + +import pytest +from lightning import seed_everything +from lightning.data.cache.sampler import CacheBatchSampler + + +@pytest.mark.parametrize( + "params", + [ + ( + 21, + 1, + [[0, 1, 2], [7, 8, 9], [14, 15, 16], [3, 4, 5], [10, 11, 12], [17, 18, 19], [6], [13], [20]], + [[7, 0, 0], [1, 1, 1], [5, 5, 5], [0, 4, 4], [8, 3, 3], [2, 2, 2], [4], [3], [6]], + ), + ( + 11, + 1, + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]], + [[1, 1, 1], [3, 3], [0, 0, 0], [2, 2, 2]], + ), + (8, 1, [[0, 1], [2, 3], [4, 5, 6], [], [], [7]], [[1, 1, 2], [3], [0, 0], [2, 2]]), + (4, 1, [[0], [1], [2, 3]], [[0], [1], [2, 2]]), + ( + 9, + 1, + [[0, 1, 2], [3, 4, 5], [6, 7, 8]], + [[0, 0, 0], [1, 1, 1], [2, 2, 2]], + ), + ( + 19, + 1, + [[0, 1, 2], [6, 7, 8], [12, 13, 14], [3, 4, 5], [9, 10, 11], [15, 16, 17], [], [], [18]], + [[0, 0, 0], [1, 1, 1], [5, 5, 5], [2, 2, 2], [4, 4, 4], [3, 3, 3], [6]], + ), + (19, 2, [[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[0, 0, 0], [5, 5, 5], [4, 4, 4], [6]]), + ], +) +def test_cache_batch_sampler(params): + seed_everything(42) + + cache = mock.MagicMock() + cache.filled = False + if params[1] > 1: + batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache) + batches = [] + for batch in batch_sampler: + batches.append(batch) + assert batches == params[2], batches + + batch_sampler = CacheBatchSampler(params[0], 1, 0, 3, 3, False, True, cache) + batches = [] + for batch in batch_sampler: + batches.append(batch) + + chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)] + else: + batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache) + batches = [] + for batch in batch_sampler: + batches.append(batch) + assert batches == params[2], batches + + chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)] + + cache.filled = True + cache.get_chunk_interval.return_value = chunks_interval + + seed_everything(42) + + batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache) + + batches_1 = [] + for batch in batch_sampler: + batches_1.append(batch) + + def validate_batch(data, check_values): + if params[1] == 1: + assert all(b[0].chunk_indexes is not None for b in data[:3]) + assert all(b[1].chunk_indexes is None if len(b) > 1 else True for b in data[:3]) + assert all(b[0].chunk_indexes is None if len(b) else True for b in data[3:]) + if check_values: + assert [[x.chunk_index for x in d] for d in data] == params[3] + else: + assert all(b[0].chunk_indexes is not None for b in data[:3]) + assert all(b[1].chunk_indexes is None if len(b) > 1 else True for b in data[:3]) + assert all(b[0].chunk_indexes is None if len(b) else True for b in data[3:]) + if check_values: + assert [[x.chunk_index for x in d] for d in data] == params[3] + + validate_batch(batches_1, True) + + batches_2 = [] + for batch in batch_sampler: + batches_2.append(batch) + + validate_batch(batches_2, False) + if params[1] == 1: + assert batches_1 != batches_2 + + +def test_batch_sampler_imagenet(): + """Validate the Imagenet dataset is valid.""" + dataset_size = 1281167 + world_size = 1 + rank = 0 + num_workers = 32 + batch_size = 8 + cache = mock.MagicMock() + cache.filled = False + CacheBatchSampler(dataset_size, world_size, rank, num_workers, batch_size, False, True, cache) diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py new file mode 100644 index 0000000000..cc4cefbead --- /dev/null +++ 
b/tests/tests_data/cache/test_serializer.py @@ -0,0 +1,115 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from time import time + +import numpy as np +import pytest +import torch +from lightning import seed_everything +from lightning.data.cache.serializers import ( + _SERIALIZERS, + _TORCH_DTYPES_MAPPING, + IntSerializer, + PickleSerializer, + PILSerializer, + TensorSerializer, +) +from lightning_utilities.core.imports import RequirementCache + +_PIL_AVAILABLE = RequirementCache("PIL") + + +def test_serializers(): + assert list(_SERIALIZERS.keys()) == ["file", "pil", "int", "jpeg", "bytes", "tensor", "pickle"] + + +def test_int_serializer(): + serializer = IntSerializer() + + for i in range(100): + data, _ = serializer.serialize(i) + assert isinstance(data, bytes) + assert i == serializer.deserialize(data) + + +@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +@pytest.mark.parametrize("mode", ["I", "L", "RGB"]) +def test_pil_serializer(mode): + serializer = PILSerializer() + + from PIL import Image + + np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32) + img = Image.fromarray(np_data).convert(mode) + + data, _ = serializer.serialize(img) + assert isinstance(data, bytes) + + deserialized_img = serializer.deserialize(data) + deserialized_img = deserialized_img.convert("I") + np_dec_data = np.asarray(deserialized_img, dtype=np.uint32) + assert isinstance(deserialized_img, Image.Image) + + # Validate data content + assert np.array_equal(np_data, np_dec_data) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows") +def test_tensor_serializer(): + seed_everything(42) + + serializer_tensor = TensorSerializer() + serializer_pickle = PickleSerializer() + + ratio_times = [] + ratio_bytes = [] + shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)] + for dtype in _TORCH_DTYPES_MAPPING.values(): + for shape in shapes: + # Not serializable for some reasons + if dtype in [torch.bfloat16]: + continue + tensor = torch.ones(shape, dtype=dtype) + + t0 = time() + data, _ = serializer_tensor.serialize(tensor) + deserialized_tensor = serializer_tensor.deserialize(data) + tensor_time = time() - t0 + tensor_bytes = len(data) + + assert deserialized_tensor.dtype == dtype + assert torch.equal(tensor, deserialized_tensor) + + t1 = time() + data, _ = serializer_pickle.serialize(tensor) + deserialized_tensor = serializer_pickle.deserialize(data) + pickle_time = time() - t1 + pickle_bytes = len(data) + + assert deserialized_tensor.dtype == dtype + assert torch.equal(tensor, deserialized_tensor) + + ratio_times.append(pickle_time / tensor_time) + ratio_bytes.append(pickle_bytes / tensor_bytes) + + assert np.mean(ratio_times) > 3.5 + assert np.mean(ratio_bytes) > 2 + + +def test_assert_bfloat16_tensor_serializer(): + serializer = TensorSerializer() + tensor = torch.ones((10,), dtype=torch.bfloat16) + with pytest.raises(TypeError, match="Got unsupported ScalarType BFloat16"): + 
serializer.serialize(tensor) diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py new file mode 100644 index 0000000000..331be0b5fa --- /dev/null +++ b/tests/tests_data/cache/test_writer.py @@ -0,0 +1,168 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import numpy as np +import pytest +from lightning.data.cache.reader import BinaryReader +from lightning.data.cache.sampler import ChunkedIndex +from lightning.data.cache.writer import BinaryWriter +from lightning_utilities.core.imports import RequirementCache + +_PIL_AVAILABLE = RequirementCache("PIL") + + +def test_binary_writer_with_ints_and_chunk_bytes(tmpdir): + with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."): + BinaryWriter("dontexists", {}) + + with pytest.raises(ValueError, match="No compresion algorithms are installed."): + BinaryWriter(tmpdir, {"i": "int"}, compression="something_else") + + binary_writer = BinaryWriter(tmpdir, chunk_bytes=90) + + for i in range(100): + binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2} + + assert len(os.listdir(tmpdir)) == 19 + binary_writer.done() + binary_writer.merge() + assert len(os.listdir(tmpdir)) == 21 + + with open(os.path.join(tmpdir, "index.json")) as f: + data = json.load(f) + + assert data["chunks"][0]["chunk_size"] == 6 + assert data["chunks"][1]["chunk_size"] == 5 + assert data["chunks"][-1]["chunk_size"] == 4 + + chunk_sizes = np.cumsum([chunk["chunk_size"] for chunk in data["chunks"]]) + + reader = BinaryReader(tmpdir) + for i in range(100): + for chunk_index, chunk_start in enumerate(chunk_sizes): + if i >= chunk_start: + continue + break + data = reader.read(ChunkedIndex(i, chunk_index=chunk_index)) + assert data == {"i": i, "i+1": i + 1, "i+2": i + 2} + + +def test_binary_writer_with_ints_and_chunk_size(tmpdir): + with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."): + BinaryWriter("dontexists", {}) + + with pytest.raises(ValueError, match="No compresion algorithms are installed."): + BinaryWriter(tmpdir, {"i": "int"}, compression="something_else") + + binary_writer = BinaryWriter(tmpdir, chunk_size=25) + + for i in range(100): + binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2} + + assert len(os.listdir(tmpdir)) == 3 + binary_writer.done() + binary_writer.merge() + assert len(os.listdir(tmpdir)) == 5 + + with open(os.path.join(tmpdir, "index.json")) as f: + data = json.load(f) + + assert data["chunks"][0]["chunk_size"] == 25 + assert data["chunks"][1]["chunk_size"] == 25 + assert data["chunks"][-1]["chunk_size"] == 25 + + reader = BinaryReader(tmpdir) + for i in range(100): + data = reader.read(ChunkedIndex(i, chunk_index=i // 25)) + assert data == {"i": i, "i+1": i + 1, "i+2": i + 2} + + +@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +def test_binary_writer_with_jpeg_and_int(tmpdir): + """Validate the writer and reader can serialize 
/ deserialize a pair of image and label.""" + from PIL import Image + + cache_dir = os.path.join(tmpdir, "chunks") + os.makedirs(cache_dir, exist_ok=True) + binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12) + + imgs = [] + + for i in range(100): + path = os.path.join(tmpdir, f"img{i}.jpeg") + np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8) + img = Image.fromarray(np_data).convert("L") + img.save(path, format="jpeg", quality=100) + img = Image.open(path) + imgs.append(img) + binary_writer[i] = {"x": img, "y": i} + + assert len(os.listdir(cache_dir)) == 24 + binary_writer.done() + binary_writer.merge() + assert len(os.listdir(cache_dir)) == 26 + + with open(os.path.join(cache_dir, "index.json")) as f: + data = json.load(f) + + assert data["chunks"][0]["chunk_size"] == 4 + assert data["chunks"][1]["chunk_size"] == 4 + assert data["chunks"][-1]["chunk_size"] == 4 + + reader = BinaryReader(cache_dir) + for i in range(100): + data = reader.read(ChunkedIndex(i, chunk_index=i // 4)) + np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i]) + assert data["y"] == i + + +@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +def test_binary_writer_with_jpeg_filepath_and_int(tmpdir): + """Validate the writer and reader can serialize / deserialize a pair of image and label.""" + from PIL import Image + + cache_dir = os.path.join(tmpdir, "chunks") + os.makedirs(cache_dir, exist_ok=True) + binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12) + + imgs = [] + + for i in range(100): + path = os.path.join(tmpdir, f"img{i}.jpeg") + np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8) + img = Image.fromarray(np_data).convert("L") + img.save(path, format="jpeg", quality=100) + img = Image.open(path) + imgs.append(img) + binary_writer[i] = {"x": path, "y": i} + + assert len(os.listdir(cache_dir)) == 24 + binary_writer.done() + binary_writer.merge() + assert len(os.listdir(cache_dir)) == 26 + + with open(os.path.join(cache_dir, "index.json")) as f: + data = json.load(f) + + assert data["chunks"][0]["chunk_size"] == 4 + assert data["chunks"][1]["chunk_size"] == 4 + assert data["chunks"][-1]["chunk_size"] == 4 + + reader = BinaryReader(cache_dir) + for i in range(100): + data = reader.read(ChunkedIndex(i, chunk_index=i // 4)) + np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i]) + assert data["y"] == i diff --git a/tests/tests_data/datasets/test_get_index.py b/tests/tests_data/datasets/test_get_index.py index 3dad62bf62..b51f10966c 100644 --- a/tests/tests_data/datasets/test_get_index.py +++ b/tests/tests_data/datasets/test_get_index.py @@ -17,8 +17,7 @@ def get_test_index_data(index_path): return list(dict.fromkeys([item.split("/")[-1] for item in data if "jpeg" in item])) -@pytest.fixture(scope="session") -def image_set(tmp_path_factory): +def image_set(tmpdir): from PIL import Image file_nums = [ @@ -45,11 +44,11 @@ def image_set(tmp_path_factory): img = img.astype(np.uint8) im = Image.fromarray(img) - for i in file_nums: - fn = tmp_path_factory.mktemp("test_data") / f"img-{i}.jpeg" - im.save(fn) + folder_path = os.path.join(tmpdir, "test_data") + os.makedirs(folder_path, exist_ok=True) - return tmp_path_factory.getbasetemp()._str + for i in file_nums: + im.save(os.path.join(folder_path, f"img-{i}.jpeg")) @pytest.mark.xfail(strict=False, reason="Need a valid AWS key and AWS secret key in CI for this to work") @@ -70,7 +69,6 @@ def test_get_index_generate_for_s3_bucket(monkeypatch): test_bucket = 
"s3://nohaspublictestbucket" index_path = os.path.join(os.getcwd(), "index_1.txt") - print(index_path) got_index = get_index(s3_connection_path=test_bucket, index_file_path=index_path) assert got_index @@ -80,13 +78,16 @@ def test_get_index_generate_for_s3_bucket(monkeypatch): assert len(test_index_data) == len(generated_index) assert test_index_data == generated_index + os.remove(index_path) @pytest.mark.skipif(not package_available("lightning"), reason="Supported only with mono-package") @mock.patch("lightning.data.datasets.index.LightningClient", MagicMock()) -def test_get_index_generate_for_local_folder(image_set, monkeypatch): +def test_get_index_generate_for_local_folder(monkeypatch, tmpdir): """Can generate an index for an s3 bucket.""" + image_set(tmpdir) + client = MagicMock() client.projects_service_list_project_cluster_bindings.return_value = None client.data_connection_service_list_data_connections.return_value = None @@ -100,7 +101,7 @@ def test_get_index_generate_for_local_folder(image_set, monkeypatch): # test_local_bucket = "data/test_dataset" index_path = os.path.join(THIS_DIR, "index_2.txt") - got_index = get_index(s3_connection_path=image_set, index_file_path=index_path) + got_index = get_index(s3_connection_path=str(tmpdir), index_file_path=index_path) assert got_index