Improve DatasetOptimizer API (#18827)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
parent 1a5718aa38
commit e59dc41c8e
@@ -1,4 +1,4 @@
-lightning-cloud ==0.5.42 # Must be pinned to ensure compatibility
+lightning-cloud ==0.5.43 # Must be pinned to ensure compatibility
 packaging
 typing-extensions >=4.0.0, <4.8.0
 deepdiff >=5.7.0, <6.6.0

@@ -113,7 +113,7 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
     else:
         upload_paths = [local_src]

-    upload_urls = []
+    _upload_urls = []

     clusters = client.projects_service_list_project_cluster_bindings(project_id)

@@ -129,9 +129,11 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
                 body=ProjectIdStorageBody(cluster_id=cluster.cluster_id, filename=filename),
                 async_req=True,
             )
-            upload_urls.append(response)
+            _upload_urls.append(response)

-    upload_urls = [upload_url.get().upload_url for upload_url in upload_urls]
+    upload_urls = []
+    for upload_url in _upload_urls:
+        upload_urls.extend(upload_url.get().urls)

     live.stop()
@@ -1,5 +1,12 @@
 from lightning.data.datasets import LightningDataset, LightningIterableDataset
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
+from lightning.data.streaming.dataset_optimizer import DatasetOptimizer

-__all__ = ["LightningDataset", "StreamingDataset", "StreamingDataLoader", "LightningIterableDataset"]
+__all__ = [
+    "LightningDataset",
+    "StreamingDataset",
+    "StreamingDataLoader",
+    "LightningIterableDataset",
+    "DatasetOptimizer",
+]
@@ -11,18 +11,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Literal, Optional, Union
+from typing import Any, List, Literal, Optional, Union

-from torch.utils.data import Dataset
+import numpy as np
+from torch.utils.data import IterableDataset

 from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
 from lightning.data.streaming import Cache
+from lightning.data.streaming.item_loader import BaseItemLoader
+from lightning.data.streaming.sampler import ChunkedIndex


-class StreamingDataset(Dataset):
+class StreamingDataset(IterableDataset):
     """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""

     def __init__(
-        self, name: str, version: Optional[Union[int, Literal["latest"]]] = "latest", cache_dir: Optional[str] = None
+        self,
+        name: str,
+        version: Optional[Union[int, Literal["latest"]]] = "latest",
+        cache_dir: Optional[str] = None,
+        item_loader: Optional[BaseItemLoader] = None,
+        shuffle: bool = True,
+        seed: int = 42,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -30,17 +40,106 @@ class StreamingDataset(Dataset):
             name: The name of the optimised dataset.
             version: The version of the dataset to use.
             cache_dir: The cache dir where the data would be stored.
+            item_loader: The logic to load an item from a chunk.
+            shuffle: Whether to shuffle the data.
+            seed: Random seed for shuffling.

         """
         super().__init__()
-        self.cache = Cache(name=name, version=version, cache_dir=cache_dir)
+        self.cache = Cache(name=name, version=version, cache_dir=cache_dir, item_loader=item_loader, chunk_bytes=1)
+
+        self.cache._reader._try_load_config()
+
+        if not self.cache.filled:
+            raise ValueError(f"The provided dataset `{name}` isn't filled up.")
+
+        self.shuffle = shuffle
+        self.distributed_env = _DistributedEnv.detect()
+        self.worker_env: Optional[_WorkerEnv] = None
+
+        chunk_intervals = self.cache.get_chunk_interval()
+        self.L = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
+
+        self.worker_chunks: List[int] = []
+        self.worker_intervals: List[List[int]] = []
+        self.current_indexes: List[int] = []
+        self.chunk_index = 0
+        self.index = 0
+        self.has_triggered_download = False
+        self.min_items_per_replica: Optional[int] = None
+        self.seed = seed
+        self.num_iter = 0
+        self.random_state = None

     def __len__(self) -> int:
-        return len(self.cache)
+        return self.L

-    def __getitem__(self, idx: int) -> Any:
-        return self.cache[idx]
+    def __iter__(self) -> "StreamingDataset":
+        self.random_state = np.random.RandomState(seed=self.seed + self.num_iter)  # type: ignore
+        chunk_intervals = self.cache.get_chunk_interval()
+        indexes = range(len(chunk_intervals))
+        shuffled_indexes = self.random_state.permutation(indexes) if self.shuffle else list(indexes)
+        shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

-    def getitem(self, obj: Any) -> Any:
-        """Override the getitem with your own logic to transform the cache object."""
-        return obj
+        chunks_per_replica: List[List[int]] = [[] for _ in range(self.distributed_env.world_size)]
+        intervals_per_replica: List[List[List[int]]] = [[] for _ in range(self.distributed_env.world_size)]
+        for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
+            replica_index = index % self.distributed_env.world_size
+            chunks_per_replica[replica_index].append(chunk_index)
+            intervals_per_replica[replica_index].append(chunk_interval)
+
+        current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
+        current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
+
+        if self.worker_env is None:
+            self.worker_env = _WorkerEnv.detect()
+
+        self.worker_chunks = []
+        self.worker_intervals = []
+
+        for i, (chunk_index, chunk_interval) in enumerate(zip(current_chunks, current_intervals)):
+            if i % self.worker_env.world_size != self.worker_env.rank:
+                continue
+            self.worker_chunks.append(chunk_index)
+            self.worker_intervals.append(chunk_interval)
+
+        self.current_indexes = []
+        self.chunk_index = 0
+        self.num_iter += 1
+
+        return self
+
+    def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
+        if isinstance(index, int):
+            index = ChunkedIndex(index, self.cache._get_chunk_index_from_index(index))
+        return self.cache[index]
+
+    def __next__(self) -> Any:
+        # Lazily re-populate the interval to reduce memory usage.
+        if len(self.current_indexes) == 0:
+            if self.chunk_index == len(self.worker_intervals):
+                raise StopIteration
+
+            interval = self.worker_intervals[self.chunk_index]
+            current_indexes = np.arange(0, interval[1] - interval[0])
+            if self.shuffle:
+                current_indexes = self.random_state.permutation(current_indexes)
+            self.current_indexes = current_indexes.tolist()
+            self.chunk_index += 1
+
+        # Get the first index
+        index = self.current_indexes.pop(0)
+
+        # Call the `__getitem__` method.
+        data = self.__getitem__(
+            ChunkedIndex(
+                index=index,
+                chunk_index=self.worker_chunks[self.chunk_index - 1],
+                chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
+            )
+        )
+
+        self.has_triggered_download = True
+        self.index += 1
+
+        return data
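Since StreamingDataset now subclasses IterableDataset and implements `__iter__`/`__next__`, it is consumed by iteration: chunks are shuffled per epoch, split across replicas, and then split across dataloader workers. A minimal usage sketch assembled from the new test further below (the dataset name "choco", the cache path, and the block size are illustrative values; the chunks must already have been produced by the DatasetOptimizer, otherwise the constructor raises the "isn't filled up" error):

from torch.utils.data import DataLoader

from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader

# Assumes an optimised dataset named "choco" was already written into this cache directory.
dataset = StreamingDataset(name="choco", cache_dir="/cache_dir", item_loader=TokensLoader(block_size=10))

# Plain iteration: each epoch reshuffles the chunk order when shuffle=True.
for item in dataset:
    pass

# Or hand it to a regular DataLoader; each worker iterates only its own subset of chunks.
dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
for batch in dataloader:
    pass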
@@ -3,15 +3,15 @@ import os
 import signal
 import traceback
 import types
-from abc import ABC, abstractmethod
 from enum import Enum
 from multiprocessing import Process, Queue
 from pathlib import Path
 from queue import Empty
 from shutil import copyfile
+from textwrap import dedent
 from threading import Thread
 from time import sleep, time
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Tuple, TypeVar, runtime_checkable
 from urllib import parse

 from tqdm.auto import tqdm
@@ -167,7 +167,7 @@ class BaseWorker:
         start_index: int,
         dataset_name: str,
         node_rank: int,
-        dataset_optimizer: "DatasetOptimizer",
+        prepare_item: Callable,
         src_dir: str,
         remote_src_dir: str,
         remote_dst_dir: Optional[str],

@@ -187,7 +187,7 @@ class BaseWorker:
         self.start_index = start_index
         self.dataset_name = dataset_name
         self.node_rank = node_rank
-        self.prepare_item = dataset_optimizer.prepare_item
+        self.prepare_item = prepare_item
         self.src_dir = src_dir
         self.remote_src_dir = remote_src_dir
         self.remote_dst_dir = remote_dst_dir
@@ -432,57 +432,21 @@ class WorkerType(Enum):
     PROCESS = "process"


-class DatasetOptimizer(ABC):
-    @abstractmethod
-    def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]:
-        """This function is meant to return a list of item metadata. Each item metadata should be enough to prepare a
-        single item when called with the prepare_item.
-
-        Example::
-
-            # For a classification use case
-
-            def prepare_dataset_structure(self, src_dir, filepaths)
-                import numpy as np
-
-                filepaths = ['class_a/file_1.ext', ..., 'class_b/file_1.ext', ...]
-                classes = np.unique([filepath.split("/")[0] for filepath in filepaths])
-                classes_to_idx_map = {c: idx for idx, c in enumerate(classes)}
-
-                # Return pair with the filepath to the obj and its class
-                # [('class_a/file_1.ext', 0), ... ('class_b/file_1.ext', 1)]
-                return [(filepath, classes_to_idx_map[filepath.split("/")[0]]) for filepath in filepaths]
-
-        Example::
-
-            # For a image segmentation use case
-
-            def prepare_dataset_structure(self, src_dir, filepaths)
-                import numpy as np
-
-                filepaths = ['file_1.JPEG', 'file_1.mask', .... 'file_N.JPEG', 'file_N.mask', ...]
-
-                # [('file_1.JPEG', 'file_1.mask'), ... ('file_N.JPEG', 'file_N.mask')]
-                return [(x[i], x[i+1]) for i in range(len(filepaths) -1)]
-
-            def prepare_item(self, obj):
-                image_filepath, mask_filepath = obj
-
-                image = load_and_resize(image_filepath)
-                mask = load_and_resize(mask_filepath)
-                return (image, mask)
-
-        """
-
-    def prepare_item(self, metadata_item: Any) -> Any:
-        """Using some metadata, prepare the associated item.
-
-        The output of this function will be binarised
-
-        """
-        return metadata_item
-
+T = TypeVar("T")
+
+
+@runtime_checkable
+class _OptimizableDataset(Protocol):
+    @staticmethod
+    def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]:
+        pass
+
+    @staticmethod
+    def prepare_item(item_metadata: T) -> Any:
+        return item_metadata
+
+
+class DatasetOptimizer:
     def __init__(
         self,
         name: str,
@@ -547,9 +511,29 @@ class DatasetOptimizer(ABC):
         )
         self.random_seed = random_seed

-    def run(self) -> None:
+    def run(self, optimizable_dataset: _OptimizableDataset) -> None:
         """The `DatasetChunker.run(...)` method is used to trigger the data processing from your dataset into
         chunks."""
+        if not isinstance(optimizable_dataset, _OptimizableDataset):
+            raise ValueError(
+                dedent(
+                    """The provided argument to the DatasetOptimizer.run(...) needs to have the following format:
+
+                    Example:
+
+                        class YourDataset:
+
+                            @staticmethod
+                            def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]:
+                                return [...]
+
+                            @staticmethod
+                            def prepare_item(item_metadata: T) -> Any:
+                                return ...
+                    """
+                )
+            )
+
         t0 = time()
         print(f"Setup started for `{self.name}` with fast_dev_run={self.fast_dev_run}.")

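In practice, `run(...)` now accepts any object whose `prepare_dataset_structure` and `prepare_item` methods satisfy the runtime-checkable `_OptimizableDataset` protocol, instead of requiring a DatasetOptimizer subclass. A minimal sketch assembled from the error message above and the updated tests (the dataset name, source directory, and worker/chunk settings are illustrative values, not part of this diff):

from typing import Any, List

from lightning.data.streaming.dataset_optimizer import DatasetOptimizer


class MyDataset:
    # One metadata entry per item; here we simply hand back the discovered file paths.
    @staticmethod
    def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[Any]:
        return filepaths

    # Turn one metadata entry into the object that gets serialised into a chunk.
    @staticmethod
    def prepare_item(item_metadata: Any) -> Any:
        return item_metadata


optimizer = DatasetOptimizer(name="my_dataset", src_dir="/data/raw", chunk_size=2, num_workers=1, num_downloaders=1)
optimizer.run(MyDataset())  # raises ValueError if the argument does not match the protocol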
@@ -564,7 +548,7 @@ class DatasetOptimizer(ABC):
         seed_everything(self.random_seed)

         # Call the setup method of the user
-        user_items = self.prepare_dataset_structure(self.src_dir, filepaths)
+        user_items: List[Any] = optimizable_dataset.prepare_dataset_structure(self.src_dir, filepaths)

         if not isinstance(user_items, list):
             raise ValueError("The setup_fn should return a list of item metadata.")

@@ -588,9 +572,9 @@ class DatasetOptimizer(ABC):
         signal.signal(signal.SIGINT, self._signal_handler)

         if self.worker_type == WorkerType.THREAD.value:
-            self._create_thread_workers(begins, workers_user_items)
+            self._create_thread_workers(optimizable_dataset, begins, workers_user_items)
         else:
-            self._create_process_workers(begins, workers_user_items)
+            self._create_process_workers(optimizable_dataset, begins, workers_user_items)

         print("Workers are ready ! Starting data processing...")

@@ -634,7 +618,9 @@ class DatasetOptimizer(ABC):
             w.join(0)
         raise RuntimeError(f"We found the following error {error}.")

-    def _create_thread_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
+    def _create_thread_workers(
+        self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]]
+    ) -> None:
         current_total = 0
         total = sum([len(w) for w in workers_user_items])
         with tqdm(total=total, smoothing=0) as pbar:

@@ -649,7 +635,7 @@ class DatasetOptimizer(ABC):
                     begins[worker_idx],
                     self.name,
                     _get_node_rank(),
-                    self,
+                    optimizable_dataset.prepare_item,
                     self.src_dir,
                     self.remote_src_dir,
                     self.remote_dst_dir,

@@ -676,7 +662,9 @@ class DatasetOptimizer(ABC):
                 if current_total == total:
                     break

-    def _create_process_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
+    def _create_process_workers(
+        self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]]
+    ) -> None:
         self.progress_queue = Queue()
         workers: List[DataWorkerProcess] = []
         stop_queues: List[Queue] = []

@@ -688,7 +676,7 @@ class DatasetOptimizer(ABC):
                 begins[worker_idx],
                 self.name,
                 _get_node_rank(),
-                self,
+                optimizable_dataset.prepare_item,
                 self.src_dir,
                 self.remote_src_dir,
                 self.remote_dst_dir,
@@ -62,7 +62,7 @@ class PyTreeLoader(BaseItemLoader):
         return intervals

     def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> bytes:
-        offset = (1 + (index - begin)) * 4
+        offset = (1 + (index - begin) if index >= begin else index + 1) * 4

         while not os.path.exists(chunk_filepath):
             sleep(0.0001)
@@ -115,9 +115,10 @@ class TokensLoader(BaseItemLoader):
         end = 0
         for chunk in self._chunks:
             dim = chunk["dim"]
-            end += dim // self._block_size
+            num_blocks = dim // self._block_size
+            end += num_blocks
             self._intervals.append((begin, end))
-            begin += end
+            begin += num_blocks

         return self._intervals

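The `begin += num_blocks` change matters from the third chunk onwards: advancing `begin` by the running `end` total over-shoots and produces wrong intervals. A standalone sketch of the corrected arithmetic with illustrative numbers (three chunks of dim 100, block_size 10):

# Corrected interval computation, mirroring the loop in the hunk above (numbers are illustrative).
chunks = [{"dim": 100}, {"dim": 100}, {"dim": 100}]
block_size = 10

begin, end, intervals = 0, 0, []
for chunk in chunks:
    num_blocks = chunk["dim"] // block_size
    end += num_blocks
    intervals.append((begin, end))
    begin += num_blocks  # the old `begin += end` would yield (0, 10), (10, 20), (30, 30) here

print(intervals)  # [(0, 10), (10, 20), (20, 30)] -- contiguous, non-overlapping block ranges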
@@ -136,5 +137,5 @@ class TokensLoader(BaseItemLoader):
         assert self._dtype

         buffer: bytes = self._buffers[chunk_index]
-        offset = self._dtype.itemsize * index * self._block_size
+        offset = self._dtype.itemsize * index
         return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)

@@ -47,7 +47,7 @@ def test_cp_local_to_remote(tmpdir, monkeypatch):
     )

     result = MagicMock()
-    result.get.return_value = V1UploadProjectArtifactResponse(upload_url="http://foo.bar")
+    result.get.return_value = V1UploadProjectArtifactResponse(urls=["http://foo.bar"])
     client.lightningapp_instance_service_upload_project_artifact.return_value = result

     monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))
@@ -23,10 +23,12 @@ from lightning.data.datasets.env import _DistributedEnv
 from lightning.data.streaming import Cache
 from lightning.data.streaming import cache as cache_module
 from lightning.data.streaming.dataloader import StreamingDataLoader
+from lightning.data.streaming.dataset import StreamingDataset
+from lightning.data.streaming.item_loader import TokensLoader
 from lightning.fabric import Fabric
 from lightning.pytorch.demos.boring_classes import RandomDataset
 from lightning_utilities.core.imports import RequirementCache
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset

 _PIL_AVAILABLE = RequirementCache("PIL")
 _TORCH_VISION_AVAILABLE = RequirementCache("torchvision")

@@ -113,11 +115,23 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):

     assert indexes2 != indexes

+    streaming_dataset = StreamingDataset(name="dummy", cache_dir=cache_dir)
+    for i in range(len(streaming_dataset)):
+        cached_data = streaming_dataset[i]
+        original_data = dataset.data[i]
+        assert cached_data["class"] == original_data["class"]
+        original_array = PILToTensor()(Image.open(original_data["image"]))
+        assert torch.equal(original_array, cached_data["image"])
+
+    streaming_dataset_iter = iter(streaming_dataset)
+    for _ in streaming_dataset_iter:
+        pass
+

 @pytest.mark.skipif(
     condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']"
 )
-@pytest.mark.parametrize("num_workers", [1])
+@pytest.mark.parametrize("num_workers", [0, 1, 2])
 def test_cache_for_image_dataset(num_workers, tmpdir):
     cache_dir = os.path.join(tmpdir, "cache")
     os.makedirs(cache_dir)
@@ -218,3 +232,27 @@ def test_cache_with_name(tmpdir, monkeypatch):
     assert cache._writer._chunk_size == 2
     assert cache._writer._cache_dir == os.path.join(tmpdir, "something")
     assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir")
+
+
+def test_streaming_dataset(tmpdir, monkeypatch):
+    seed_everything(42)
+
+    os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True)
+    monkeypatch.setattr(cache_module, "_try_create_cache_dir", lambda name: tmpdir)
+
+    with pytest.raises(ValueError, match="The provided dataset `choco` isn't filled up."):
+        dataset = StreamingDataset(name="choco", cache_dir=tmpdir)
+
+    dataset = RandomDataset(128, 64)
+    dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12)
+    for batch in dataloader:
+        assert isinstance(batch, torch.Tensor)
+
+    dataset = StreamingDataset(name="choco", cache_dir=tmpdir, item_loader=TokensLoader(block_size=10))
+
+    assert len(dataset) == 816
+    dataset_iter = iter(dataset)
+    assert len(dataset_iter) == 816
+
+    dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
+    assert len(dataloader) == 408
@@ -6,7 +6,7 @@ from unittest import mock
 import numpy as np
 import pytest
 import torch
-from lightning import seed_everything
+from lightning import LightningDataModule, seed_everything
 from lightning.data.streaming.dataset_optimizer import (
     DatasetOptimizer,
     _download_data_target,

@@ -131,11 +131,14 @@ def test_download_data_target(tmpdir):
     assert os.listdir(cache_dir) == ["a.txt"]


-class TestDatasetOptimizer(DatasetOptimizer):
+class DataModuleImage(LightningDataModule):
     def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]:
         assert len(filepaths) == 30
         return filepaths

+    def prepare_item(self, item):
+        return item
+

 @pytest.mark.parametrize("delete_cached_files", [False, True])
 @pytest.mark.parametrize("fast_dev_run", [False, True])
@@ -154,7 +157,7 @@ def test_data_optimizer(fast_dev_run, delete_cached_files, tmpdir, monkeypatch):
     cache_dir = os.path.join(tmpdir, "cache")
     monkeypatch.setenv("DATA_OPTIMIZER_HOME_FOLDER", home_dir)
     monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
-    datasetOptimizer = TestDatasetOptimizer(
+    dataset_optimizer = DatasetOptimizer(
         name="dummy_dataset",
         src_dir=tmpdir,
         chunk_size=2,

@@ -165,7 +168,7 @@ def test_data_optimizer(fast_dev_run, delete_cached_files, tmpdir, monkeypatch):
         delete_cached_files=delete_cached_files,
         fast_dev_run=fast_dev_run,
     )
-    datasetOptimizer.run()
+    dataset_optimizer.run(DataModuleImage())

     assert sorted(os.listdir(cache_dir)) == ["data", "dummy_dataset"]

@@ -242,7 +245,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
     monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
     monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
     monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
-    datasetOptimizer = TestDatasetOptimizer(
+    dataset_optimizer = DatasetOptimizer(
         name="dummy_dataset",
         src_dir=tmpdir,
         chunk_size=2,

@@ -254,7 +257,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
         fast_dev_run=fast_dev_run,
         remote_dst_dir=remote_dst_dir,
     )
-    datasetOptimizer.run()
+    dataset_optimizer.run(DataModuleImage())

     assert sorted(os.listdir(cache_dir)) == ["data", "dummy_dataset"]

@@ -276,7 +279,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
     monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
     monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
     monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
-    datasetOptimizer = TestDatasetOptimizer(
+    dataset_optimizer = DatasetOptimizer(
         name="dummy_dataset",
         src_dir=tmpdir,
         chunk_size=2,

@@ -288,7 +291,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
         fast_dev_run=fast_dev_run,
         remote_dst_dir=remote_dst_dir,
     )
-    datasetOptimizer.run()
+    dataset_optimizer.run(DataModuleImage())

     assert sorted(os.listdir(cache_dir)) == ["data", "dummy_dataset"]

@@ -309,11 +312,13 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
     assert sorted(os.listdir(remote_dst_dir)) == expected


-class NlpDatasetOptimizer(DatasetOptimizer):
-    def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]:
+class DataModule(LightningDataModule):
+    @staticmethod
+    def prepare_dataset_structure(src_dir: str, filepaths: List[str]) -> List[Any]:
         return [os.path.join(src_dir, "dummy2")]

-    def prepare_item(self, filepath):
+    @staticmethod
+    def prepare_item(filepath):
         for _ in range(100):
             yield torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int)

@@ -327,7 +332,15 @@ def test_data_optimizer_nlp(tmpdir, monkeypatch):
     with open(os.path.join(tmpdir, "dummy.txt"), "w") as f:
         f.write("Hello World !")

-    dataset_optimizer = NlpDatasetOptimizer(
+    dataset_optimizer = DatasetOptimizer(
         name="dummy2", src_dir=tmpdir, num_workers=1, num_downloaders=1, chunk_size=1024 * 11
     )
-    dataset_optimizer.run()
+    dataset_optimizer.run(DataModule())
+
+
+def test_data_optimizer_api(tmpdir):
+    dataset_optimizer = DatasetOptimizer(
+        name="dummy2", src_dir=tmpdir, num_workers=1, num_downloaders=1, chunk_size=1024 * 11
+    )
+    with pytest.raises(ValueError, match="prepare_dataset_structure"):
+        dataset_optimizer.run(None)