Improve DatasetOptimizer API (#18827)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
thomas chaton 2023-10-23 18:06:48 +01:00 committed by GitHub
parent 1a5718aa38
commit e59dc41c8e
9 changed files with 241 additions and 93 deletions

View File

@ -1,4 +1,4 @@
lightning-cloud ==0.5.42 # Must be pinned to ensure compatibility
lightning-cloud ==0.5.43 # Must be pinned to ensure compatibility
packaging
typing-extensions >=4.0.0, <4.8.0
deepdiff >=5.7.0, <6.6.0

View File

@ -113,7 +113,7 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
else:
upload_paths = [local_src]
upload_urls = []
_upload_urls = []
clusters = client.projects_service_list_project_cluster_bindings(project_id)
@ -129,9 +129,11 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
body=ProjectIdStorageBody(cluster_id=cluster.cluster_id, filename=filename),
async_req=True,
)
upload_urls.append(response)
_upload_urls.append(response)
upload_urls = [upload_url.get().upload_url for upload_url in upload_urls]
upload_urls = []
for upload_url in _upload_urls:
upload_urls.extend(upload_url.get().urls)
live.stop()

View File

@ -1,5 +1,12 @@
from lightning.data.datasets import LightningDataset, LightningIterableDataset
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.dataset_optimizer import DatasetOptimizer
__all__ = ["LightningDataset", "StreamingDataset", "StreamingDataLoader", "LightningIterableDataset"]
__all__ = [
"LightningDataset",
"StreamingDataset",
"StreamingDataLoader",
"LightningIterableDataset",
"DatasetOptimizer",
]
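
With `DatasetOptimizer` added to `__all__`, the new class can be imported directly from `lightning.data` alongside the existing streaming classes, e.g.:

    from lightning.data import (
        DatasetOptimizer,
        StreamingDataLoader,
        StreamingDataset,
    )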

View File

@ -11,18 +11,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Union
from torch.utils.data import Dataset
import numpy as np
from torch.utils.data import IterableDataset
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
from lightning.data.streaming import Cache
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
class StreamingDataset(Dataset):
class StreamingDataset(IterableDataset):
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""
def __init__(
self, name: str, version: Optional[Union[int, Literal["latest"]]] = "latest", cache_dir: Optional[str] = None
self,
name: str,
version: Optional[Union[int, Literal["latest"]]] = "latest",
cache_dir: Optional[str] = None,
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = True,
seed: int = 42,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@ -30,17 +40,106 @@ class StreamingDataset(Dataset):
name: The name of the optimised dataset.
version: The version of the dataset to use.
cache_dir: The directory where the cached data will be stored.
item_loader: The logic to load an item from a chunk.
shuffle: Whether to shuffle the data.
seed: Random seed for shuffling.
"""
super().__init__()
self.cache = Cache(name=name, version=version, cache_dir=cache_dir)
self.cache = Cache(name=name, version=version, cache_dir=cache_dir, item_loader=item_loader, chunk_bytes=1)
self.cache._reader._try_load_config()
if not self.cache.filled:
raise ValueError(f"The provided dataset `{name}` isn't filled up.")
self.shuffle = shuffle
self.distributed_env = _DistributedEnv.detect()
self.worker_env: Optional[_WorkerEnv] = None
chunk_intervals = self.cache.get_chunk_interval()
self.L = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
self.worker_chunks: List[int] = []
self.worker_intervals: List[List[int]] = []
self.current_indexes: List[int] = []
self.chunk_index = 0
self.index = 0
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
self.seed = seed
self.num_iter = 0
self.random_state = None
def __len__(self) -> int:
return len(self.cache)
return self.L
def __getitem__(self, idx: int) -> Any:
return self.cache[idx]
def __iter__(self) -> "StreamingDataset":
self.random_state = np.random.RandomState(seed=self.seed + self.num_iter) # type: ignore
chunk_intervals = self.cache.get_chunk_interval()
indexes = range(len(chunk_intervals))
shuffled_indexes = self.random_state.permutation(indexes) if self.shuffle else list(indexes)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
def getitem(self, obj: Any) -> Any:
"""Override the getitem with your own logic to transform the cache object."""
return obj
chunks_per_replica: List[List[int]] = [[] for _ in range(self.distributed_env.world_size)]
intervals_per_replica: List[List[List[int]]] = [[] for _ in range(self.distributed_env.world_size)]
for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
replica_index = index % self.distributed_env.world_size
chunks_per_replica[replica_index].append(chunk_index)
intervals_per_replica[replica_index].append(chunk_interval)
current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
if self.worker_env is None:
self.worker_env = _WorkerEnv.detect()
self.worker_chunks = []
self.worker_intervals = []
for i, (chunk_index, chunk_interval) in enumerate(zip(current_chunks, current_intervals)):
if i % self.worker_env.world_size != self.worker_env.rank:
continue
self.worker_chunks.append(chunk_index)
self.worker_intervals.append(chunk_interval)
self.current_indexes = []
self.chunk_index = 0
self.num_iter += 1
return self
def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
if isinstance(index, int):
index = ChunkedIndex(index, self.cache._get_chunk_index_from_index(index))
return self.cache[index]
def __next__(self) -> Any:
# Lazily re-populate the interval to reduce memory usage.
if len(self.current_indexes) == 0:
if self.chunk_index == len(self.worker_intervals):
raise StopIteration
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(0, interval[1] - interval[0])
if self.shuffle:
current_indexes = self.random_state.permutation(current_indexes)
self.current_indexes = current_indexes.tolist()
self.chunk_index += 1
# Get the first index
index = self.current_indexes.pop(0)
# Call the `__getitem__` method.
data = self.__getitem__(
ChunkedIndex(
index=index,
chunk_index=self.worker_chunks[self.chunk_index - 1],
chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
)
)
self.has_triggered_download = True
self.index += 1
return data
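
In practice, the reworked `StreamingDataset` is consumed as an iterable rather than indexed like a map-style dataset. A minimal usage sketch based on the new constructor and the test changes further down (the dataset name and cache directory are placeholders):

    from torch.utils.data import DataLoader

    from lightning.data import StreamingDataset
    from lightning.data.streaming.item_loader import TokensLoader

    # Placeholders: "my_dataset" must already have been optimized into chunks,
    # and cache_dir must point at (or be able to fetch) those chunks.
    dataset = StreamingDataset(
        name="my_dataset",
        cache_dir="/path/to/cache",
        item_loader=TokensLoader(block_size=10),
        shuffle=True,
        seed=42,
    )

    # As an IterableDataset, chunk assignment across replicas and dataloader
    # workers happens inside __iter__, so a plain DataLoader works as usual.
    for batch in DataLoader(dataset, num_workers=2, batch_size=2):
        ...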

View File

@ -3,15 +3,15 @@ import os
import signal
import traceback
import types
from abc import ABC, abstractmethod
from enum import Enum
from multiprocessing import Process, Queue
from pathlib import Path
from queue import Empty
from shutil import copyfile
from textwrap import dedent
from threading import Thread
from time import sleep, time
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Tuple, TypeVar, runtime_checkable
from urllib import parse
from tqdm.auto import tqdm
@ -167,7 +167,7 @@ class BaseWorker:
start_index: int,
dataset_name: str,
node_rank: int,
dataset_optimizer: "DatasetOptimizer",
prepare_item: Callable,
src_dir: str,
remote_src_dir: str,
remote_dst_dir: Optional[str],
@ -187,7 +187,7 @@ class BaseWorker:
self.start_index = start_index
self.dataset_name = dataset_name
self.node_rank = node_rank
self.prepare_item = dataset_optimizer.prepare_item
self.prepare_item = prepare_item
self.src_dir = src_dir
self.remote_src_dir = remote_src_dir
self.remote_dst_dir = remote_dst_dir
@ -432,57 +432,21 @@ class WorkerType(Enum):
PROCESS = "process"
class DatasetOptimizer(ABC):
@abstractmethod
def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]:
"""This function is meant to return a list of item metadata. Each item metadata should be enough to prepare a
single item when called with the prepare_item.
T = TypeVar("T")
Example::
# For a classification use case
def prepare_dataset_structure(self, src_dir, filepaths)
import numpy as np
filepaths = ['class_a/file_1.ext', ..., 'class_b/file_1.ext', ...]
classes = np.unique([filepath.split("/")[0] for filepath in filepaths])
classes_to_idx_map = {c: idx for idx, c in enumerate(classes)}
# Return pair with the filepath to the obj and its class
# [('class_a/file_1.ext', 0), ... ('class_b/file_1.ext', 1)]
return [(filepath, classes_to_idx_map[filepath.split("/")[0]]) for filepath in filepaths]
Example::
# For a image segmentation use case
def prepare_dataset_structure(self, src_dir, filepaths)
import numpy as np
filepaths = ['file_1.JPEG', 'file_1.mask', .... 'file_N.JPEG', 'file_N.mask', ...]
# [('file_1.JPEG', 'file_1.mask'), ... ('file_N.JPEG', 'file_N.mask')]
return [(x[i], x[i+1]) for i in range(len(filepaths) -1)]
def prepare_item(self, obj):
image_filepath, mask_filepath = obj
image = load_and_resize(image_filepath)
mask = load_and_resize(mask_filepath)
return (image, mask)
"""
@runtime_checkable
class _OptimizableDataset(Protocol):
@staticmethod
def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]:
pass
def prepare_item(self, metadata_item: Any) -> Any:
"""Using some metadata, prepare the associated item.
@staticmethod
def prepare_item(item_metadata: T) -> Any:
return item_metadata
The output of this function will be binarised
"""
return metadata_item
class DatasetOptimizer:
def __init__(
self,
name: str,
@ -547,9 +511,29 @@ class DatasetOptimizer(ABC):
)
self.random_seed = random_seed
def run(self) -> None:
def run(self, optimizable_dataset: _OptimizableDataset) -> None:
"""The `DatasetChunker.run(...)` method is used to trigger the data processing from your dataset into
chunks."""
if not isinstance(optimizable_dataset, _OptimizableDataset):
raise ValueError(
dedent(
"""The provided argument to the DatasetOptimizer.run(...) needs to have the following format:
Example:
class YourDataset:
@staticmethod
def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]:
return [...]
@staticmethod
def prepare_item(item_metadata: T) -> Any:
return ...
"""
)
)
t0 = time()
print(f"Setup started for `{self.name}` with fast_dev_run={self.fast_dev_run}.")
@ -564,7 +548,7 @@ class DatasetOptimizer(ABC):
seed_everything(self.random_seed)
# Call the setup method of the user
user_items = self.prepare_dataset_structure(self.src_dir, filepaths)
user_items: List[Any] = optimizable_dataset.prepare_dataset_structure(self.src_dir, filepaths)
if not isinstance(user_items, list):
raise ValueError("The setup_fn should return a list of item metadata.")
@ -588,9 +572,9 @@ class DatasetOptimizer(ABC):
signal.signal(signal.SIGINT, self._signal_handler)
if self.worker_type == WorkerType.THREAD.value:
self._create_thread_workers(begins, workers_user_items)
self._create_thread_workers(optimizable_dataset, begins, workers_user_items)
else:
self._create_process_workers(begins, workers_user_items)
self._create_process_workers(optimizable_dataset, begins, workers_user_items)
print("Workers are ready ! Starting data processing...")
@ -634,7 +618,9 @@ class DatasetOptimizer(ABC):
w.join(0)
raise RuntimeError(f"We found the following error {error}.")
def _create_thread_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
def _create_thread_workers(
self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]]
) -> None:
current_total = 0
total = sum([len(w) for w in workers_user_items])
with tqdm(total=total, smoothing=0) as pbar:
@ -649,7 +635,7 @@ class DatasetOptimizer(ABC):
begins[worker_idx],
self.name,
_get_node_rank(),
self,
optimizable_dataset.prepare_item,
self.src_dir,
self.remote_src_dir,
self.remote_dst_dir,
@ -676,7 +662,9 @@ class DatasetOptimizer(ABC):
if current_total == total:
break
def _create_process_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
def _create_process_workers(
self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]]
) -> None:
self.progress_queue = Queue()
workers: List[DataWorkerProcess] = []
stop_queues: List[Queue] = []
@ -688,7 +676,7 @@ class DatasetOptimizer(ABC):
begins[worker_idx],
self.name,
_get_node_rank(),
self,
optimizable_dataset.prepare_item,
self.src_dir,
self.remote_src_dir,
self.remote_dst_dir,
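
Taken together, the API change replaces subclassing `DatasetOptimizer` with passing any object that satisfies the `_OptimizableDataset` protocol to `run(...)`. A minimal sketch of the new call pattern, with a hypothetical dataset class and hypothetical paths:

    from typing import Any, List

    from lightning.data import DatasetOptimizer


    class TextFiles:
        # Any object exposing these two static methods satisfies the
        # runtime-checkable _OptimizableDataset protocol verified in run(...).
        @staticmethod
        def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[Any]:
            # Return one metadata entry per item to optimize.
            return [filepath for filepath in filepaths if filepath.endswith(".txt")]

        @staticmethod
        def prepare_item(item_metadata: Any) -> Any:
            # Turn a metadata entry into the object that gets serialized into a chunk.
            with open(item_metadata) as f:
                return f.read()


    # Hypothetical name and src_dir; the remaining arguments mirror the tests below.
    optimizer = DatasetOptimizer(name="my_text_dataset", src_dir="/data/raw", num_workers=1, chunk_size=1024)
    optimizer.run(TextFiles())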

View File

@ -62,7 +62,7 @@ class PyTreeLoader(BaseItemLoader):
return intervals
def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> bytes:
offset = (1 + (index - begin)) * 4
offset = (1 + (index - begin) if index >= begin else index + 1) * 4
while not os.path.exists(chunk_filepath):
sleep(0.0001)
@ -115,9 +115,10 @@ class TokensLoader(BaseItemLoader):
end = 0
for chunk in self._chunks:
dim = chunk["dim"]
end += dim // self._block_size
num_blocks = dim // self._block_size
end += num_blocks
self._intervals.append((begin, end))
begin += end
begin += num_blocks
return self._intervals
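
For clarity, the `TokensLoader` interval fix keeps `begin` in block units rather than accumulating the running `end` total. A small worked example of the corrected arithmetic (the chunk metadata below is made up):

    block_size = 10
    chunks = [{"dim": 100}, {"dim": 100}, {"dim": 50}]  # hypothetical chunk metadata

    intervals = []
    begin = 0
    end = 0
    for chunk in chunks:
        num_blocks = chunk["dim"] // block_size
        end += num_blocks
        intervals.append((begin, end))
        begin += num_blocks  # previously `begin += end`, wrong from the third chunk onward

    print(intervals)  # [(0, 10), (10, 20), (20, 25)]
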
@ -136,5 +137,5 @@ class TokensLoader(BaseItemLoader):
assert self._dtype
buffer: bytes = self._buffers[chunk_index]
offset = self._dtype.itemsize * index * self._block_size
offset = self._dtype.itemsize * index
return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)

View File

@ -47,7 +47,7 @@ def test_cp_local_to_remote(tmpdir, monkeypatch):
)
result = MagicMock()
result.get.return_value = V1UploadProjectArtifactResponse(upload_url="http://foo.bar")
result.get.return_value = V1UploadProjectArtifactResponse(urls=["http://foo.bar"])
client.lightningapp_instance_service_upload_project_artifact.return_value = result
monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))

View File

@ -23,10 +23,12 @@ from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming import Cache
from lightning.data.streaming import cache as cache_module
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader
from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning_utilities.core.imports import RequirementCache
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
@ -113,11 +115,23 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):
assert indexes2 != indexes
streaming_dataset = StreamingDataset(name="dummy", cache_dir=cache_dir)
for i in range(len(streaming_dataset)):
cached_data = streaming_dataset[i]
original_data = dataset.data[i]
assert cached_data["class"] == original_data["class"]
original_array = PILToTensor()(Image.open(original_data["image"]))
assert torch.equal(original_array, cached_data["image"])
streaming_dataset_iter = iter(streaming_dataset)
for _ in streaming_dataset_iter:
pass
@pytest.mark.skipif(
condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']"
)
@pytest.mark.parametrize("num_workers", [1])
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_cache_for_image_dataset(num_workers, tmpdir):
cache_dir = os.path.join(tmpdir, "cache")
os.makedirs(cache_dir)
@ -218,3 +232,27 @@ def test_cache_with_name(tmpdir, monkeypatch):
assert cache._writer._chunk_size == 2
assert cache._writer._cache_dir == os.path.join(tmpdir, "something")
assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir")
def test_streaming_dataset(tmpdir, monkeypatch):
seed_everything(42)
os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True)
monkeypatch.setattr(cache_module, "_try_create_cache_dir", lambda name: tmpdir)
with pytest.raises(ValueError, match="The provided dataset `choco` isn't filled up."):
dataset = StreamingDataset(name="choco", cache_dir=tmpdir)
dataset = RandomDataset(128, 64)
dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12)
for batch in dataloader:
assert isinstance(batch, torch.Tensor)
dataset = StreamingDataset(name="choco", cache_dir=tmpdir, item_loader=TokensLoader(block_size=10))
assert len(dataset) == 816
dataset_iter = iter(dataset)
assert len(dataset_iter) == 816
dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
assert len(dataloader) == 408

View File

@ -6,7 +6,7 @@ from unittest import mock
import numpy as np
import pytest
import torch
from lightning import seed_everything
from lightning import LightningDataModule, seed_everything
from lightning.data.streaming.dataset_optimizer import (
DatasetOptimizer,
_download_data_target,
@ -131,11 +131,14 @@ def test_download_data_target(tmpdir):
assert os.listdir(cache_dir) == ["a.txt"]
class TestDatasetOptimizer(DatasetOptimizer):
class DataModuleImage(LightningDataModule):
def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]:
assert len(filepaths) == 30
return filepaths
def prepare_item(self, item):
return item
@pytest.mark.parametrize("delete_cached_files", [False, True])
@pytest.mark.parametrize("fast_dev_run", [False, True])
@ -154,7 +157,7 @@ def test_data_optimizer(fast_dev_run, delete_cached_files, tmpdir, monkeypatch):
cache_dir = os.path.join(tmpdir, "cache")
monkeypatch.setenv("DATA_OPTIMIZER_HOME_FOLDER", home_dir)
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
datasetOptimizer = TestDatasetOptimizer(
dataset_optimizer = DatasetOptimizer(
name="dummy_dataset",
src_dir=tmpdir,
chunk_size=2,
@ -165,7 +168,7 @@ def test_data_optimizer(fast_dev_run, delete_cached_files, tmpdir, monkeypatch):
delete_cached_files=delete_cached_files,
fast_dev_run=fast_dev_run,
)
datasetOptimizer.run()
dataset_optimizer.run(DataModuleImage())
assert sorted(os.listdir(cache_dir)) == ["data", "dummy_dataset"]
@ -242,7 +245,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
datasetOptimizer = TestDatasetOptimizer(
dataset_optimizer = DatasetOptimizer(
name="dummy_dataset",
src_dir=tmpdir,
chunk_size=2,
@ -254,7 +257,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
fast_dev_run=fast_dev_run,
remote_dst_dir=remote_dst_dir,
)
datasetOptimizer.run()
dataset_optimizer.run(DataModuleImage())
assert sorted(os.listdir(cache_dir)) == ["data", "dummy_dataset"]
@ -276,7 +279,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
datasetOptimizer = TestDatasetOptimizer(
dataset_optimizer = DatasetOptimizer(
name="dummy_dataset",
src_dir=tmpdir,
chunk_size=2,
@ -288,7 +291,7 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
fast_dev_run=fast_dev_run,
remote_dst_dir=remote_dst_dir,
)
datasetOptimizer.run()
dataset_optimizer.run(DataModuleImage())
assert sorted(os.listdir(cache_dir)) == ["data", "dummy_dataset"]
@ -309,11 +312,13 @@ def test_data_optimizer_distributed(fast_dev_run, delete_cached_files, tmpdir, m
assert sorted(os.listdir(remote_dst_dir)) == expected
class NlpDatasetOptimizer(DatasetOptimizer):
def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]:
class DataModule(LightningDataModule):
@staticmethod
def prepare_dataset_structure(src_dir: str, filepaths: List[str]) -> List[Any]:
return [os.path.join(src_dir, "dummy2")]
def prepare_item(self, filepath):
@staticmethod
def prepare_item(filepath):
for _ in range(100):
yield torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int)
@ -327,7 +332,15 @@ def test_data_optimizer_nlp(tmpdir, monkeypatch):
with open(os.path.join(tmpdir, "dummy.txt"), "w") as f:
f.write("Hello World !")
dataset_optimizer = NlpDatasetOptimizer(
dataset_optimizer = DatasetOptimizer(
name="dummy2", src_dir=tmpdir, num_workers=1, num_downloaders=1, chunk_size=1024 * 11
)
dataset_optimizer.run()
dataset_optimizer.run(DataModule())
def test_data_optimizer_api(tmpdir):
dataset_optimizer = DatasetOptimizer(
name="dummy2", src_dir=tmpdir, num_workers=1, num_downloaders=1, chunk_size=1024 * 11
)
with pytest.raises(ValueError, match="prepare_dataset_structure"):
dataset_optimizer.run(None)