Cache directory per worker to avoid collisions (#18957)

Adrian Wälchli 2023-11-07 16:19:03 +01:00 committed by GitHub
parent 529f07f254
commit 8a5d3423a7
5 changed files with 166 additions and 90 deletions
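
In short: `StreamingDataset` now creates its `Cache` lazily inside each worker (on the first `len()`, `iter()`, or `__getitem__` call), and the resolved cache directory includes the worker's shard rank, so concurrent dataloader workers no longer write chunks into the same folder. A rough sketch of the scenario this protects, with a placeholder dataset path; the per-rank `/cache/chunks/<md5>/<rank>` layout below only applies when the Lightning cloud environment variables are set:

from torch.utils.data import DataLoader
from lightning.data.streaming.dataset import StreamingDataset

# Placeholder directory that already contains optimized chunks and an index file.
dataset = StreamingDataset(input_dir="/data/optimized")
loader = DataLoader(dataset, num_workers=2, batch_size=2)

for batch in loader:
    # Each worker detects its own _WorkerEnv on first use and builds its own
    # Cache (and, on the cloud, its own cache directory) instead of sharing one.
    ...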


@@ -57,8 +57,6 @@ class Cache:
Arguments:
input_dir: The path to where the chunks will be or are stored.
name: The name of the dataset in the cloud.
version: The version of the dataset in the cloud to use. By default, we will use the latest.
compression: The name of the algorithm to reduce the size of the chunks.
chunk_bytes: The maximum number of bytes within a chunk.
chunk_size: The maximum number of items within a chunk.
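
For context, a minimal write/read roundtrip with `Cache`, mirroring the tests touched in this commit (the directory is a placeholder):

from lightning.data.streaming import Cache
from lightning.data.streaming.dataset import StreamingDataset

cache_dir = "/tmp/my_chunks"  # placeholder local directory
cache = Cache(cache_dir, chunk_size=10)
for i in range(12):
    cache[i] = i   # items accumulate into chunks of at most 10 items
cache.done()       # flush the last, partially filled chunk
cache.merge()      # write the index file that StreamingDataset looks for

dataset = StreamingDataset(input_dir=cache_dir)
assert len(dataset) == 12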


@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import hashlib
import os
from typing import Any, List, Optional, Union
@@ -18,7 +19,7 @@ from typing import Any, List, Optional, Union
import numpy as np
from torch.utils.data import IterableDataset
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
from lightning.data.datasets.env import Environment, _DistributedEnv, _WorkerEnv
from lightning.data.streaming import Cache
from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50
from lightning.data.streaming.item_loader import BaseItemLoader
@@ -57,27 +58,14 @@ class StreamingDataset(IterableDataset):
input_dir = _resolve_dir(input_dir)
# Override the provided input_path
cache_dir = _try_create_cache_dir(input_dir.path)
if cache_dir:
input_dir.path = cache_dir
self.cache = Cache(input_dir=input_dir, item_loader=item_loader, chunk_bytes=1)
self.cache._reader._try_load_config()
if not self.cache.filled:
raise ValueError(
f"The provided dataset `{input_dir}` doesn't contain any {_INDEX_FILENAME} file."
" HINT: Did you successfully optimize a dataset to the provided `input_dir` ?"
)
self.distributed_env = _DistributedEnv.detect()
self.shuffle: Shuffle = (
FullShuffle(self.cache, seed, drop_last) if shuffle else NoShuffle(self.cache, seed, drop_last)
)
self.input_dir = input_dir
self.item_loader = item_loader
self.shuffle: bool = shuffle
self.drop_last = drop_last
self.seed = seed
self.cache: Optional[Cache] = None
self.distributed_env = _DistributedEnv.detect()
self.worker_env: Optional[_WorkerEnv] = None
self.worker_chunks: List[int] = []
self.worker_intervals: List[List[int]] = []
@@ -86,23 +74,52 @@ class StreamingDataset(IterableDataset):
self.index = 0
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
self.seed = seed
self.current_epoch = 0
self.random_state = None
self.shuffler: Optional[Shuffle] = None
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
cache_dir = copy.deepcopy(self.input_dir)
if cache_path:
cache_dir.path = cache_path
cache = Cache(input_dir=cache_dir, item_loader=self.item_loader, chunk_bytes=1)
cache._reader._try_load_config()
if not cache.filled:
raise ValueError(
f"The provided dataset `{self.input_dir}` doesn't contain any {_INDEX_FILENAME} file."
" HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
)
return cache
def _create_shuffler(self, cache: Cache) -> Shuffle:
return (
FullShuffle(cache, self.seed, self.drop_last)
if self.shuffle
else NoShuffle(cache, self.seed, self.drop_last)
)
def __len__(self) -> int:
return self.shuffle.get_len(self.distributed_env, self.current_epoch)
if self.shuffler is None:
cache = self._create_cache(worker_env=_WorkerEnv.detect())
self.shuffler = self._create_shuffler(cache)
return self.shuffler.get_len(self.distributed_env, self.current_epoch)
def __iter__(self) -> "StreamingDataset":
chunks_per_replica, intervals_per_replica = self.shuffle.get_chunks_and_intervals_per_ranks(
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)
self.shuffler = self._create_shuffler(self.cache)
chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
self.distributed_env, self.current_epoch
)
current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
if self.worker_env is None:
self.worker_env = _WorkerEnv.detect()
self.worker_chunks = []
self.worker_intervals = []
@@ -119,6 +136,10 @@ class StreamingDataset(IterableDataset):
return self
def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
if self.cache is None:
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)
self.shuffler = self._create_shuffler(self.cache)
if isinstance(index, int):
index = ChunkedIndex(index, self.cache._get_chunk_index_from_index(index))
return self.cache[index]
@@ -137,7 +158,9 @@ class StreamingDataset(IterableDataset):
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
self.current_indexes = self.shuffle(current_indexes)
assert self.shuffler is not None
self.current_indexes = self.shuffler(current_indexes)
self.chunk_index += 1
# Get the first index
@@ -158,10 +181,10 @@ class StreamingDataset(IterableDataset):
return data
def _try_create_cache_dir(input_dir: str) -> Optional[str]:
def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
return None
hash_object = hashlib.md5(input_dir.encode())
cache_dir = os.path.join(f"/cache/chunks/{hash_object.hexdigest()}")
cache_dir = os.path.join("/cache", "chunks", hash_object.hexdigest(), str(shard_rank))
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
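
A standalone sketch of the path scheme above, using only the standard library, showing that two shard ranks for the same `input_dir` land in sibling directories under the same hash (the expected digest for "dir" comes from the tests below; `sketch_cache_dir` is a made-up name for illustration):

import hashlib
import os

def sketch_cache_dir(input_dir: str, shard_rank: int) -> str:
    # Same derivation as _try_create_cache_dir, minus the env checks and os.makedirs().
    digest = hashlib.md5(input_dir.encode()).hexdigest()
    return os.path.join("/cache", "chunks", digest, str(shard_rank))

assert sketch_cache_dir("dir", 0) == os.path.join("/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "0")
assert sketch_cache_dir("dir", 1) != sketch_cache_dir("dir", 2)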


@@ -23,12 +23,11 @@ from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming import Cache
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader
from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Dataset
_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
@@ -222,29 +221,6 @@ def test_cache_with_auto_wrapping(tmpdir):
pass
def test_streaming_dataset(tmpdir, monkeypatch):
seed_everything(42)
os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True)
with pytest.raises(ValueError, match="The provided dataset"):
dataset = StreamingDataset(input_dir=tmpdir)
dataset = RandomDataset(128, 64)
dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12)
for batch in dataloader:
assert isinstance(batch, torch.Tensor)
dataset = StreamingDataset(input_dir=tmpdir, item_loader=TokensLoader(block_size=10))
assert len(dataset) == 816
dataset_iter = iter(dataset)
assert len(dataset_iter) == 816
dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
assert len(dataloader) == 408
def test_create_oversized_chunk_single_item(tmp_path):
cache = Cache(str(tmp_path), chunk_bytes=700)
with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):


@@ -12,49 +12,55 @@
# limitations under the License.
import os
from re import escape
from unittest import mock
import pytest
import torch
from lightning import seed_everything
from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming import Cache
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
from lightning.pytorch.demos.boring_classes import RandomDataset
from torch.utils.data import DataLoader
def test_streaming_dataset(tmpdir, monkeypatch):
seed_everything(42)
dataset = StreamingDataset(input_dir=tmpdir)
with pytest.raises(ValueError, match="The provided dataset"):
dataset = StreamingDataset(input_dir=tmpdir)
iter(dataset)
dataset = StreamingDataset(input_dir=tmpdir)
with pytest.raises(ValueError, match="The provided dataset"):
_ = dataset[0]
dataset = RandomDataset(128, 64)
dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12)
for batch in dataloader:
assert isinstance(batch, torch.Tensor)
cache = Cache(tmpdir, chunk_size=10)
for i in range(12):
cache[i] = i
cache.done()
cache.merge()
dataset = StreamingDataset(input_dir=tmpdir, item_loader=TokensLoader(block_size=10))
dataset = StreamingDataset(input_dir=tmpdir)
assert len(dataset) == 816
assert len(dataset) == 12
dataset_iter = iter(dataset)
assert len(dataset_iter) == 816
assert len(dataset_iter) == 12
dataloader = DataLoader(dataset, num_workers=2, batch_size=1)
assert len(dataloader) == 12
dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
assert len(dataloader) == 408
assert len(dataloader) == 6
@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
@mock.patch("lightning.data.streaming.dataset.os.makedirs")
def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir):
# Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call
with pytest.raises(FileNotFoundError, match="`/cache/chunks/275876e34cf609db118f3d84b799a790` doesn't exist"):
StreamingDataset("dummy")
makedirs_mock.assert_called_once_with("/cache/chunks/275876e34cf609db118f3d84b799a790", exist_ok=True)
dataset = StreamingDataset("dummy")
expected = os.path.join("/cache", "chunks", "275876e34cf609db118f3d84b799a790", "0")
with pytest.raises(FileNotFoundError, match=escape(f"`{expected}` doesn't exist")):
iter(dataset)
makedirs_mock.assert_called_once_with(expected, exist_ok=True)
@pytest.mark.parametrize("drop_last", [False, True])
@@ -69,8 +75,9 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
cache.merge()
dataset = StreamingDataset(input_dir=tmpdir, shuffle=False, drop_last=drop_last)
assert isinstance(dataset.shuffle, NoShuffle)
assert not dataset.shuffle
_ = dataset[0] # init shuffler
assert isinstance(dataset.shuffler, NoShuffle)
for i in range(101):
assert dataset[i] == i
@@ -105,7 +112,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
assert len(process_2_2) == 50
_, intervals_per_ranks = dataset.shuffle.get_chunks_and_intervals_per_ranks(
_, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_ranks(
dataset.distributed_env, dataset.current_epoch
)
@@ -149,8 +156,9 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
cache.merge()
dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
assert isinstance(dataset.shuffle, FullShuffle)
assert dataset.shuffle
_ = dataset[0]
assert isinstance(dataset.shuffler, FullShuffle)
for i in range(1097):
assert dataset[i] == i
@@ -164,7 +172,8 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
assert len(process_1_1) == 548
dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
assert isinstance(dataset_2.shuffle, FullShuffle)
iter(dataset_2)
assert isinstance(dataset_2.shuffler, FullShuffle)
dataset_2.distributed_env = _DistributedEnv(2, 1)
assert len(dataset_2) == 548 + int(not drop_last)
dataset_2_iter = iter(dataset_2)
@@ -187,8 +196,9 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
cache.merge()
dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
assert isinstance(dataset.shuffle, FullShuffle)
assert dataset.shuffle
_ = dataset[0]
assert isinstance(dataset.shuffler, FullShuffle)
for i in range(1222):
assert dataset[i] == i
@@ -202,7 +212,8 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
assert len(process_1_1) == 611
dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
assert isinstance(dataset_2.shuffle, FullShuffle)
iter(dataset_2)
assert isinstance(dataset_2.shuffler, FullShuffle)
dataset_2.distributed_env = _DistributedEnv(2, 1)
assert len(dataset_2) == 611
dataset_2_iter = iter(dataset_2)
@@ -229,6 +240,9 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch):
cache.merge()
dataset = StreamingDataset(input_dir=remote_dir, shuffle=True)
assert dataset.cache is None
iter(dataset)
assert dataset.cache is not None
assert dataset.cache._reader._prepare_thread is None
dataset.cache._reader._prepare_thread = True
dataloader = DataLoader(dataset, num_workers=1)
@@ -240,11 +254,75 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch):
assert len(batches) == 10
@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
@mock.patch("lightning.data.streaming.dataset.os.makedirs")
def test_try_create_cache_dir(makedirs, monkeypatch):
cache_dir_1 = _try_create_cache_dir("")
cache_dir_2 = _try_create_cache_dir("ssdf")
assert cache_dir_1 != cache_dir_2
assert cache_dir_1 == "/cache/chunks/d41d8cd98f00b204e9800998ecf8427e"
assert len(makedirs._mock_mock_calls) == 2
def test_dataset_cache_recreation(tmpdir):
"""Test that we recreate the cache and other objects only when appropriate."""
cache = Cache(tmpdir, chunk_size=10)
for i in range(10):
cache[i] = i
cache.done()
cache.merge()
# repeated `len()` calls
dataset = StreamingDataset(input_dir=tmpdir)
assert not dataset.cache
assert not dataset.shuffler
len(dataset)
assert not dataset.cache
shuffler = dataset.shuffler
assert isinstance(shuffler, NoShuffle)
len(dataset)
assert dataset.shuffler is shuffler
# repeated `iter()` calls
dataset = StreamingDataset(input_dir=tmpdir)
assert not dataset.cache
assert not dataset.shuffler
iter(dataset)
cache = dataset.cache
shuffler = dataset.shuffler
assert isinstance(cache, Cache)
assert isinstance(shuffler, NoShuffle)
iter(dataset)
assert isinstance(dataset.cache, Cache)
assert isinstance(dataset.shuffler, NoShuffle)
assert dataset.cache is not cache # cache gets recreated
assert dataset.shuffler is not shuffler # shuffler gets recreated
# repeated `getitem()` calls
dataset = StreamingDataset(input_dir=tmpdir)
assert not dataset.cache
assert not dataset.shuffler
_ = dataset[0]
cache = dataset.cache
shuffler = dataset.shuffler
assert isinstance(cache, Cache)
assert isinstance(shuffler, NoShuffle)
_ = dataset[1]
assert dataset.cache is cache # cache gets reused
assert dataset.shuffler is shuffler # shuffler gets reused
def test_try_create_cache_dir():
with mock.patch.dict(os.environ, {}, clear=True):
assert _try_create_cache_dir("any") is None
# creating the cache dir under /cache requires root privileges, so we need to mock `os.makedirs()`
with (
mock.patch.dict("os.environ", {"LIGHTNING_CLUSTER_ID": "abc", "LIGHTNING_CLOUD_PROJECT_ID": "123"}),
mock.patch("lightning.data.streaming.dataset.os.makedirs") as makedirs_mock,
):
cache_dir_1 = _try_create_cache_dir("")
cache_dir_2 = _try_create_cache_dir("ssdf")
assert cache_dir_1 != cache_dir_2
assert cache_dir_1 == os.path.join("/cache", "chunks", "d41d8cd98f00b204e9800998ecf8427e", "0")
assert len(makedirs_mock.mock_calls) == 2
assert _try_create_cache_dir("dir", shard_rank=0) == os.path.join(
"/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "0"
)
assert _try_create_cache_dir("dir", shard_rank=1) == os.path.join(
"/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "1"
)
assert _try_create_cache_dir("dir", shard_rank=2) == os.path.join(
"/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "2"
)


@@ -67,6 +67,7 @@ def test_pil_serializer(mode):
assert np.array_equal(np_data, np_dec_data)
@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_tensor_serializer():
seed_everything(42)