From e59dc41c8eec9fb237ff8ab977cb39d8d7511510 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 23 Oct 2023 18:06:48 +0100 Subject: [PATCH] Improve DatasetOptimizer API (#18827) Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: thomas --- requirements/app/app.txt | 2 +- src/lightning/app/cli/commands/cp.py | 8 +- src/lightning/data/__init__.py | 9 +- src/lightning/data/streaming/dataset.py | 121 ++++++++++++++++-- .../data/streaming/dataset_optimizer.py | 102 +++++++-------- src/lightning/data/streaming/item_loader.py | 9 +- tests/tests_app/cli/test_cp.py | 2 +- tests/tests_data/streaming/test_cache.py | 42 +++++- .../streaming/test_data_optimizer.py | 39 ++++-- 9 files changed, 241 insertions(+), 93 deletions(-) diff --git a/requirements/app/app.txt b/requirements/app/app.txt index 0f3945efe6..254d71affb 100644 --- a/requirements/app/app.txt +++ b/requirements/app/app.txt @@ -1,4 +1,4 @@ -lightning-cloud ==0.5.42 # Must be pinned to ensure compatibility +lightning-cloud ==0.5.43 # Must be pinned to ensure compatibility packaging typing-extensions >=4.0.0, <4.8.0 deepdiff >=5.7.0, <6.6.0 diff --git a/src/lightning/app/cli/commands/cp.py b/src/lightning/app/cli/commands/cp.py index 069b09820e..8fd31b8d2c 100644 --- a/src/lightning/app/cli/commands/cp.py +++ b/src/lightning/app/cli/commands/cp.py @@ -113,7 +113,7 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str else: upload_paths = [local_src] - upload_urls = [] + _upload_urls = [] clusters = client.projects_service_list_project_cluster_bindings(project_id) @@ -129,9 +129,11 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str body=ProjectIdStorageBody(cluster_id=cluster.cluster_id, filename=filename), async_req=True, ) - upload_urls.append(response) + _upload_urls.append(response) - upload_urls = [upload_url.get().upload_url for upload_url in upload_urls] + upload_urls = [] + for upload_url in _upload_urls: + upload_urls.extend(upload_url.get().urls) live.stop() diff --git a/src/lightning/data/__init__.py b/src/lightning/data/__init__.py index 35677ed18f..851d1c0fe4 100644 --- a/src/lightning/data/__init__.py +++ b/src/lightning/data/__init__.py @@ -1,5 +1,12 @@ from lightning.data.datasets import LightningDataset, LightningIterableDataset from lightning.data.streaming.dataloader import StreamingDataLoader from lightning.data.streaming.dataset import StreamingDataset +from lightning.data.streaming.dataset_optimizer import DatasetOptimizer -__all__ = ["LightningDataset", "StreamingDataset", "StreamingDataLoader", "LightningIterableDataset"] +__all__ = [ + "LightningDataset", + "StreamingDataset", + "StreamingDataLoader", + "LightningIterableDataset", + "DatasetOptimizer", +] diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index 548d37747b..bb30be4f99 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -11,18 +11,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Union -from torch.utils.data import Dataset +import numpy as np +from torch.utils.data import IterableDataset +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv from lightning.data.streaming import Cache +from lightning.data.streaming.item_loader import BaseItemLoader +from lightning.data.streaming.sampler import ChunkedIndex -class StreamingDataset(Dataset): +class StreamingDataset(IterableDataset): """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.""" def __init__( - self, name: str, version: Optional[Union[int, Literal["latest"]]] = "latest", cache_dir: Optional[str] = None + self, + name: str, + version: Optional[Union[int, Literal["latest"]]] = "latest", + cache_dir: Optional[str] = None, + item_loader: Optional[BaseItemLoader] = None, + shuffle: bool = True, + seed: int = 42, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -30,17 +40,106 @@ class StreamingDataset(Dataset): name: The name of the optimised dataset. version: The version of the dataset to use. cache_dir: The cache dir where the data would be stored. + item_loader: The logic to load an item from a chunk. + shuffle: Whether to shuffle the data. + seed: Random seed for shuffling. """ super().__init__() - self.cache = Cache(name=name, version=version, cache_dir=cache_dir) + self.cache = Cache(name=name, version=version, cache_dir=cache_dir, item_loader=item_loader, chunk_bytes=1) + + self.cache._reader._try_load_config() + + if not self.cache.filled: + raise ValueError(f"The provided dataset `{name}` isn't filled up.") + + self.shuffle = shuffle + self.distributed_env = _DistributedEnv.detect() + self.worker_env: Optional[_WorkerEnv] = None + + chunk_intervals = self.cache.get_chunk_interval() + self.L = sum([(interval[-1] - interval[0]) for interval in chunk_intervals]) + + self.worker_chunks: List[int] = [] + self.worker_intervals: List[List[int]] = [] + self.current_indexes: List[int] = [] + self.chunk_index = 0 + self.index = 0 + self.has_triggered_download = False + self.min_items_per_replica: Optional[int] = None + self.seed = seed + self.num_iter = 0 + self.random_state = None def __len__(self) -> int: - return len(self.cache) + return self.L - def __getitem__(self, idx: int) -> Any: - return self.cache[idx] + def __iter__(self) -> "StreamingDataset": + self.random_state = np.random.RandomState(seed=self.seed + self.num_iter) # type: ignore + chunk_intervals = self.cache.get_chunk_interval() + indexes = range(len(chunk_intervals)) + shuffled_indexes = self.random_state.permutation(indexes) if self.shuffle else list(indexes) + shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] - def getitem(self, obj: Any) -> Any: - """Override the getitem with your own logic to transform the cache object.""" - return obj + chunks_per_replica: List[List[int]] = [[] for _ in range(self.distributed_env.world_size)] + intervals_per_replica: List[List[List[int]]] = [[] for _ in range(self.distributed_env.world_size)] + for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)): + replica_index = index % self.distributed_env.world_size + chunks_per_replica[replica_index].append(chunk_index) + intervals_per_replica[replica_index].append(chunk_interval) + + current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] + current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] + + if self.worker_env is None: + self.worker_env = _WorkerEnv.detect() + + self.worker_chunks = [] + self.worker_intervals = [] + + for i, (chunk_index, chunk_interval) in enumerate(zip(current_chunks, current_intervals)): + if i % self.worker_env.world_size != self.worker_env.rank: + continue + self.worker_chunks.append(chunk_index) + self.worker_intervals.append(chunk_interval) + + self.current_indexes = [] + self.chunk_index = 0 + self.num_iter += 1 + + return self + + def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: + if isinstance(index, int): + index = ChunkedIndex(index, self.cache._get_chunk_index_from_index(index)) + return self.cache[index] + + def __next__(self) -> Any: + # Lazily re-populate the interval to reduce memory usage. + if len(self.current_indexes) == 0: + if self.chunk_index == len(self.worker_intervals): + raise StopIteration + + interval = self.worker_intervals[self.chunk_index] + current_indexes = np.arange(0, interval[1] - interval[0]) + if self.shuffle: + current_indexes = self.random_state.permutation(current_indexes) + self.current_indexes = current_indexes.tolist() + self.chunk_index += 1 + + # Get the first index + index = self.current_indexes.pop(0) + + # Call the `__getitem__` method. + data = self.__getitem__( + ChunkedIndex( + index=index, + chunk_index=self.worker_chunks[self.chunk_index - 1], + chunk_indexes=None if self.has_triggered_download else self.worker_chunks, + ) + ) + + self.has_triggered_download = True + self.index += 1 + + return data diff --git a/src/lightning/data/streaming/dataset_optimizer.py b/src/lightning/data/streaming/dataset_optimizer.py index 1ed0230e7b..083e941eba 100644 --- a/src/lightning/data/streaming/dataset_optimizer.py +++ b/src/lightning/data/streaming/dataset_optimizer.py @@ -3,15 +3,15 @@ import os import signal import traceback import types -from abc import ABC, abstractmethod from enum import Enum from multiprocessing import Process, Queue from pathlib import Path from queue import Empty from shutil import copyfile +from textwrap import dedent from threading import Thread from time import sleep, time -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple +from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Tuple, TypeVar, runtime_checkable from urllib import parse from tqdm.auto import tqdm @@ -167,7 +167,7 @@ class BaseWorker: start_index: int, dataset_name: str, node_rank: int, - dataset_optimizer: "DatasetOptimizer", + prepare_item: Callable, src_dir: str, remote_src_dir: str, remote_dst_dir: Optional[str], @@ -187,7 +187,7 @@ class BaseWorker: self.start_index = start_index self.dataset_name = dataset_name self.node_rank = node_rank - self.prepare_item = dataset_optimizer.prepare_item + self.prepare_item = prepare_item self.src_dir = src_dir self.remote_src_dir = remote_src_dir self.remote_dst_dir = remote_dst_dir @@ -432,57 +432,21 @@ class WorkerType(Enum): PROCESS = "process" -class DatasetOptimizer(ABC): - @abstractmethod - def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]: - """This function is meant to return a list of item metadata. Each item metadata should be enough to prepare a - single item when called with the prepare_item. +T = TypeVar("T") - Example:: - # For a classification use case - - def prepare_dataset_structure(self, src_dir, filepaths) - import numpy as np - - filepaths = ['class_a/file_1.ext', ..., 'class_b/file_1.ext', ...] - classes = np.unique([filepath.split("/")[0] for filepath in filepaths]) - classes_to_idx_map = {c: idx for idx, c in enumerate(classes)} - - # Return pair with the filepath to the obj and its class - # [('class_a/file_1.ext', 0), ... ('class_b/file_1.ext', 1)] - return [(filepath, classes_to_idx_map[filepath.split("/")[0]]) for filepath in filepaths] - - Example:: - - # For a image segmentation use case - - def prepare_dataset_structure(self, src_dir, filepaths) - import numpy as np - - filepaths = ['file_1.JPEG', 'file_1.mask', .... 'file_N.JPEG', 'file_N.mask', ...] - - # [('file_1.JPEG', 'file_1.mask'), ... ('file_N.JPEG', 'file_N.mask')] - return [(x[i], x[i+1]) for i in range(len(filepaths) -1)] - - def prepare_item(self, obj): - image_filepath, mask_filepath = obj - - image = load_and_resize(image_filepath) - mask = load_and_resize(mask_filepath) - return (image, mask) - - """ +@runtime_checkable +class _OptimizableDataset(Protocol): + @staticmethod + def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]: pass - def prepare_item(self, metadata_item: Any) -> Any: - """Using some metadata, prepare the associated item. + @staticmethod + def prepare_item(item_metadata: T) -> Any: + return item_metadata - The output of this function will be binarised - - """ - return metadata_item +class DatasetOptimizer: def __init__( self, name: str, @@ -547,9 +511,29 @@ class DatasetOptimizer(ABC): ) self.random_seed = random_seed - def run(self) -> None: + def run(self, optimizable_dataset: _OptimizableDataset) -> None: """The `DatasetChunker.run(...)` method is used to trigger the data processing from your dataset into chunks.""" + if not isinstance(optimizable_dataset, _OptimizableDataset): + raise ValueError( + dedent( + """The provided argument to the DatasetOptimizer.run(...) needs to have the following format: + + Example: + + class YourDataset: + + @staticmethod + def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]: + return [...] + + @staticmethod + def prepare_item(item_metadata: T) -> Any: + return ... + """ + ) + ) + t0 = time() print(f"Setup started for `{self.name}` with fast_dev_run={self.fast_dev_run}.") @@ -564,7 +548,7 @@ class DatasetOptimizer(ABC): seed_everything(self.random_seed) # Call the setup method of the user - user_items = self.prepare_dataset_structure(self.src_dir, filepaths) + user_items: List[Any] = optimizable_dataset.prepare_dataset_structure(self.src_dir, filepaths) if not isinstance(user_items, list): raise ValueError("The setup_fn should return a list of item metadata.") @@ -588,9 +572,9 @@ class DatasetOptimizer(ABC): signal.signal(signal.SIGINT, self._signal_handler) if self.worker_type == WorkerType.THREAD.value: - self._create_thread_workers(begins, workers_user_items) + self._create_thread_workers(optimizable_dataset, begins, workers_user_items) else: - self._create_process_workers(begins, workers_user_items) + self._create_process_workers(optimizable_dataset, begins, workers_user_items) print("Workers are ready ! Starting data processing...") @@ -634,7 +618,9 @@ class DatasetOptimizer(ABC): w.join(0) raise RuntimeError(f"We found the following error {error}.") - def _create_thread_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None: + def _create_thread_workers( + self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]] + ) -> None: current_total = 0 total = sum([len(w) for w in workers_user_items]) with tqdm(total=total, smoothing=0) as pbar: @@ -649,7 +635,7 @@ class DatasetOptimizer(ABC): begins[worker_idx], self.name, _get_node_rank(), - self, + optimizable_dataset.prepare_item, self.src_dir, self.remote_src_dir, self.remote_dst_dir, @@ -676,7 +662,9 @@ class DatasetOptimizer(ABC): if current_total == total: break - def _create_process_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None: + def _create_process_workers( + self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]] + ) -> None: self.progress_queue = Queue() workers: List[DataWorkerProcess] = [] stop_queues: List[Queue] = [] @@ -688,7 +676,7 @@ class DatasetOptimizer(ABC): begins[worker_idx], self.name, _get_node_rank(), - self, + optimizable_dataset.prepare_item, self.src_dir, self.remote_src_dir, self.remote_dst_dir, diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py index 4abe795eac..d22db33a6c 100644 --- a/src/lightning/data/streaming/item_loader.py +++ b/src/lightning/data/streaming/item_loader.py @@ -62,7 +62,7 @@ class PyTreeLoader(BaseItemLoader): return intervals def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> bytes: - offset = (1 + (index - begin)) * 4 + offset = (1 + (index - begin) if index >= begin else index + 1) * 4 while not os.path.exists(chunk_filepath): sleep(0.0001) @@ -115,9 +115,10 @@ class TokensLoader(BaseItemLoader): end = 0 for chunk in self._chunks: dim = chunk["dim"] - end += dim // self._block_size + num_blocks = dim // self._block_size + end += num_blocks self._intervals.append((begin, end)) - begin += end + begin += num_blocks return self._intervals @@ -136,5 +137,5 @@ class TokensLoader(BaseItemLoader): assert self._dtype buffer: bytes = self._buffers[chunk_index] - offset = self._dtype.itemsize * index * self._block_size + offset = self._dtype.itemsize * index return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) diff --git a/tests/tests_app/cli/test_cp.py b/tests/tests_app/cli/test_cp.py index 77c14ae7e8..311badede7 100644 --- a/tests/tests_app/cli/test_cp.py +++ b/tests/tests_app/cli/test_cp.py @@ -47,7 +47,7 @@ def test_cp_local_to_remote(tmpdir, monkeypatch): ) result = MagicMock() - result.get.return_value = V1UploadProjectArtifactResponse(upload_url="http://foo.bar") + result.get.return_value = V1UploadProjectArtifactResponse(urls=["http://foo.bar"]) client.lightningapp_instance_service_upload_project_artifact.return_value = result monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client)) diff --git a/tests/tests_data/streaming/test_cache.py b/tests/tests_data/streaming/test_cache.py index 74e26c0b95..5c9bb1c6fe 100644 --- a/tests/tests_data/streaming/test_cache.py +++ b/tests/tests_data/streaming/test_cache.py @@ -23,10 +23,12 @@ from lightning.data.datasets.env import _DistributedEnv from lightning.data.streaming import Cache from lightning.data.streaming import cache as cache_module from lightning.data.streaming.dataloader import StreamingDataLoader +from lightning.data.streaming.dataset import StreamingDataset +from lightning.data.streaming.item_loader import TokensLoader from lightning.fabric import Fabric from lightning.pytorch.demos.boring_classes import RandomDataset from lightning_utilities.core.imports import RequirementCache -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") @@ -113,11 +115,23 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): assert indexes2 != indexes + streaming_dataset = StreamingDataset(name="dummy", cache_dir=cache_dir) + for i in range(len(streaming_dataset)): + cached_data = streaming_dataset[i] + original_data = dataset.data[i] + assert cached_data["class"] == original_data["class"] + original_array = PILToTensor()(Image.open(original_data["image"])) + assert torch.equal(original_array, cached_data["image"]) + + streaming_dataset_iter = iter(streaming_dataset) + for _ in streaming_dataset_iter: + pass + @pytest.mark.skipif( condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']" ) -@pytest.mark.parametrize("num_workers", [1]) +@pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_cache_for_image_dataset(num_workers, tmpdir): cache_dir = os.path.join(tmpdir, "cache") os.makedirs(cache_dir) @@ -218,3 +232,27 @@ def test_cache_with_name(tmpdir, monkeypatch): assert cache._writer._chunk_size == 2 assert cache._writer._cache_dir == os.path.join(tmpdir, "something") assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir") + + +def test_streaming_dataset(tmpdir, monkeypatch): + seed_everything(42) + + os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True) + monkeypatch.setattr(cache_module, "_try_create_cache_dir", lambda name: tmpdir) + + with pytest.raises(ValueError, match="The provided dataset `choco` isn't filled up."): + dataset = StreamingDataset(name="choco", cache_dir=tmpdir) + + dataset = RandomDataset(128, 64) + dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12) + for batch in dataloader: + assert isinstance(batch, torch.Tensor) + + dataset = StreamingDataset(name="choco", cache_dir=tmpdir, item_loader=TokensLoader(block_size=10)) + + assert len(dataset) == 816 + dataset_iter = iter(dataset) + assert len(dataset_iter) == 816 + + dataloader = DataLoader(dataset, num_workers=2, batch_size=2) + assert len(dataloader) == 408 diff --git a/tests/tests_data/streaming/test_data_optimizer.py b/tests/tests_data/streaming/test_data_optimizer.py index ba0c8732f1..652e5f43c2 100644 --- a/tests/tests_data/streaming/test_data_optimizer.py +++ b/tests/tests_data/streaming/test_data_optimizer.py @@ -6,7 +6,7 @@ from unittest import mock import numpy as np import pytest import torch -from lightning import seed_everything +from lightning import LightningDataModule, seed_everything from lightning.data.streaming.dataset_optimizer import ( DatasetOptimizer, _download_data_target, @@ -131,11 +131,14 @@ def test_download_data_target(tmpdir): assert os.listdir(cache_dir) == ["a.txt"] -class TestDatasetOptimizer(DatasetOptimizer): +class DataModuleImage(LightningDataModule): def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]: assert len(filepaths) == 30 return filepaths + def prepare_item(self, item): + return item + @pytest.mark.parametrize("delete_cached_files", [False, True]) @pytest.mark.parametrize("fast_dev_run", [False, True]) @@ -154,7 +157,7 @@ def test_data_optimizer(fast_dev_run, delete_cached_files, tmpdir, monkeypatch): cache_dir = os.path.join(tmpdir, "cache") monkeypatch.setenv("DATA_OPTIMIZER_HOME_FOLDER", home_dir) monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir) - datasetOptimizer = TestDatasetOptimizer( + dataset_optimizer = DatasetOptimizer( name="dummy_dataset", src_dir=tmpdir, chunk_size=2, @@ -165,7 +168,7 @@ def test_data_optimizer(fast_dev_run, delete_cached_files, tmpdir, monkeypatch): delete_cached_files=delete_cached_files, fast_dev_run=fast_dev_run, ) - datasetOptimizer.run() + dataset_optimizer.run(DataModuleImage()) assert sorted(os.listdir(cache_dir)) == ["data", "dummy_dataset"] @@ -242,7 +245,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir) monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2") monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0") - datasetOptimizer = TestDatasetOptimizer( + dataset_optimizer = DatasetOptimizer( name="dummy_dataset", src_dir=tmpdir, chunk_size=2, @@ -254,7 +257,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m fast_dev_run=fast_dev_run, remote_dst_dir=remote_dst_dir, ) - datasetOptimizer.run() + dataset_optimizer.run(DataModuleImage()) assert sorted(os.listdir(cache_dir)) == ["data", "dummy_dataset"] @@ -276,7 +279,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir) monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2") monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1") - datasetOptimizer = TestDatasetOptimizer( + dataset_optimizer = DatasetOptimizer( name="dummy_dataset", src_dir=tmpdir, chunk_size=2, @@ -288,7 +291,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m fast_dev_run=fast_dev_run, remote_dst_dir=remote_dst_dir, ) - datasetOptimizer.run() + dataset_optimizer.run(DataModuleImage()) assert sorted(os.listdir(cache_dir)) == ["data", "dummy_dataset"] @@ -309,11 +312,13 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m assert sorted(os.listdir(remote_dst_dir)) == expected -class NlpDatasetOptimizer(DatasetOptimizer): - def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]: +class DataModule(LightningDataModule): + @staticmethod + def prepare_dataset_structure(src_dir: str, filepaths: List[str]) -> List[Any]: return [os.path.join(src_dir, "dummy2")] - def prepare_item(self, filepath): + @staticmethod + def prepare_item(filepath): for _ in range(100): yield torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int) @@ -327,7 +332,15 @@ def test_data_optimizer_nlp(tmpdir, monkeypatch): with open(os.path.join(tmpdir, "dummy.txt"), "w") as f: f.write("Hello World !") - dataset_optimizer = NlpDatasetOptimizer( + dataset_optimizer = DatasetOptimizer( name="dummy2", src_dir=tmpdir, num_workers=1, num_downloaders=1, chunk_size=1024 * 11 ) - dataset_optimizer.run() + dataset_optimizer.run(DataModule()) + + +def test_data_optimizer_api(tmpdir): + dataset_optimizer = DatasetOptimizer( + name="dummy2", src_dir=tmpdir, num_workers=1, num_downloaders=1, chunk_size=1024 * 11 + ) + with pytest.raises(ValueError, match="prepare_dataset_structure"): + dataset_optimizer.run(None)