Cache directory per worker to avoid collisions (#18957)
parent 529f07f254
commit 8a5d3423a7
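In short: before this change, every rank and dataloader worker resolved the identical cache location and could collide while writing chunks into it; afterwards each worker gets its own subdirectory, and `StreamingDataset` builds its `Cache` lazily per worker instead of in `__init__`. The old and new path construction (taken from the `_try_create_cache_dir` hunk below) boil down to:

    /cache/chunks/<md5(input_dir)>                # before: shared by all ranks and workers
    /cache/chunks/<md5(input_dir)>/<shard_rank>   # after: one directory per worker

A runnable sketch of the new layout follows the `_try_create_cache_dir` hunk further down.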
@@ -57,8 +57,6 @@ class Cache:

        Arguments:
            input_dir: The path to where the chunks will be or are stored.
            name: The name of dataset in the cloud.
            version: The version of the dataset in the cloud to use. By default, we will use the latest.
            compression: The name of the algorithm to reduce the size of the chunks.
            chunk_bytes: The maximum number of bytes within a chunk.
            chunk_size: The maximum number of items within a chunk.
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import hashlib
import os
from typing import Any, List, Optional, Union
@@ -18,7 +19,7 @@ from typing import Any, List, Optional, Union
import numpy as np
from torch.utils.data import IterableDataset

from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
from lightning.data.datasets.env import Environment, _DistributedEnv, _WorkerEnv
from lightning.data.streaming import Cache
from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50
from lightning.data.streaming.item_loader import BaseItemLoader
@@ -57,27 +58,14 @@ class StreamingDataset(IterableDataset):

        input_dir = _resolve_dir(input_dir)

        # Override the provided input_path
        cache_dir = _try_create_cache_dir(input_dir.path)
        if cache_dir:
            input_dir.path = cache_dir

        self.cache = Cache(input_dir=input_dir, item_loader=item_loader, chunk_bytes=1)

        self.cache._reader._try_load_config()

        if not self.cache.filled:
            raise ValueError(
                f"The provided dataset `{input_dir}` doesn't contain any {_INDEX_FILENAME} file."
                " HINT: Did you successfully optimize a dataset to the provided `input_dir` ?"
            )

        self.distributed_env = _DistributedEnv.detect()

        self.shuffle: Shuffle = (
            FullShuffle(self.cache, seed, drop_last) if shuffle else NoShuffle(self.cache, seed, drop_last)
        )
        self.input_dir = input_dir
        self.item_loader = item_loader
        self.shuffle: bool = shuffle
        self.drop_last = drop_last
        self.seed = seed

        self.cache: Optional[Cache] = None
        self.distributed_env = _DistributedEnv.detect()
        self.worker_env: Optional[_WorkerEnv] = None
        self.worker_chunks: List[int] = []
        self.worker_intervals: List[List[int]] = []
@@ -86,23 +74,52 @@ class StreamingDataset(IterableDataset):
        self.index = 0
        self.has_triggered_download = False
        self.min_items_per_replica: Optional[int] = None
        self.seed = seed
        self.current_epoch = 0
        self.random_state = None
        self.shuffler: Optional[Shuffle] = None

    def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
        env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
        cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
        cache_dir = copy.deepcopy(self.input_dir)
        if cache_path:
            cache_dir.path = cache_path

        cache = Cache(input_dir=cache_dir, item_loader=self.item_loader, chunk_bytes=1)
        cache._reader._try_load_config()

        if not cache.filled:
            raise ValueError(
                f"The provided dataset `{self.input_dir}` doesn't contain any {_INDEX_FILENAME} file."
                " HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
            )

        return cache

    def _create_shuffler(self, cache: Cache) -> Shuffle:
        return (
            FullShuffle(cache, self.seed, self.drop_last)
            if self.shuffle
            else NoShuffle(cache, self.seed, self.drop_last)
        )

    def __len__(self) -> int:
        return self.shuffle.get_len(self.distributed_env, self.current_epoch)
        if self.shuffler is None:
            cache = self._create_cache(worker_env=_WorkerEnv.detect())
            self.shuffler = self._create_shuffler(cache)
        return self.shuffler.get_len(self.distributed_env, self.current_epoch)

    def __iter__(self) -> "StreamingDataset":
        chunks_per_replica, intervals_per_replica = self.shuffle.get_chunks_and_intervals_per_ranks(
        self.worker_env = _WorkerEnv.detect()
        self.cache = self._create_cache(worker_env=self.worker_env)
        self.shuffler = self._create_shuffler(self.cache)

        chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
            self.distributed_env, self.current_epoch
        )
        current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
        current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]

        if self.worker_env is None:
            self.worker_env = _WorkerEnv.detect()

        self.worker_chunks = []
        self.worker_intervals = []
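The new `_create_cache` keys the cache directory on `env.shard_rank`. `Environment` comes from `lightning.data.datasets.env`, and its `shard_rank` definition is not part of this diff; the sketch below is only an assumption about what such a rank needs to provide, namely a unique integer per (process, dataloader-worker) pair:

    # Assumed sketch; the real computation lives in lightning.data.datasets.env.Environment.
    def shard_rank(global_rank: int, workers_per_process: int, worker_rank: int) -> int:
        # e.g. process rank 1 with 2 workers, worker 0 -> 1 * 2 + 0 = 2
        return global_rank * workers_per_process + worker_rank

As long as no two workers in the job share a shard rank, no two of them resolve the same cache subdirectory in `_try_create_cache_dir` below.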
@@ -119,6 +136,10 @@ class StreamingDataset(IterableDataset):
        return self

    def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
        if self.cache is None:
            self.worker_env = _WorkerEnv.detect()
            self.cache = self._create_cache(worker_env=self.worker_env)
            self.shuffler = self._create_shuffler(self.cache)
        if isinstance(index, int):
            index = ChunkedIndex(index, self.cache._get_chunk_index_from_index(index))
        return self.cache[index]
@@ -137,7 +158,9 @@ class StreamingDataset(IterableDataset):

            interval = self.worker_intervals[self.chunk_index]
            current_indexes = np.arange(interval[0], interval[1])
            self.current_indexes = self.shuffle(current_indexes)

            assert self.shuffler is not None
            self.current_indexes = self.shuffler(current_indexes)
            self.chunk_index += 1

        # Get the first index
@@ -158,10 +181,10 @@ class StreamingDataset(IterableDataset):
        return data


def _try_create_cache_dir(input_dir: str) -> Optional[str]:
def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
    if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
        return None
    hash_object = hashlib.md5(input_dir.encode())
    cache_dir = os.path.join(f"/cache/chunks/{hash_object.hexdigest()}")
    cache_dir = os.path.join("/cache", "chunks", hash_object.hexdigest(), str(shard_rank))
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
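For illustration, a minimal sketch of the layout the new `_try_create_cache_dir` produces (the helper name `expected_cache_dir` is made up here, and the cloud-environment guard and `os.makedirs` call are left out):

    import hashlib
    import os

    def expected_cache_dir(input_dir: str, shard_rank: int) -> str:
        # /cache/chunks/<md5(input_dir)>/<shard_rank>
        return os.path.join("/cache", "chunks", hashlib.md5(input_dir.encode()).hexdigest(), str(shard_rank))

    # Same dataset, two workers: same hash, different leaf directories.
    assert expected_cache_dir("dir", 0) != expected_cache_dir("dir", 1)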
@@ -23,12 +23,11 @@ from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming import Cache
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader
from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Dataset

_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
@@ -222,29 +221,6 @@ def test_cache_with_auto_wrapping(tmpdir):
    pass


def test_streaming_dataset(tmpdir, monkeypatch):
    seed_everything(42)

    os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True)

    with pytest.raises(ValueError, match="The provided dataset"):
        dataset = StreamingDataset(input_dir=tmpdir)

    dataset = RandomDataset(128, 64)
    dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12)
    for batch in dataloader:
        assert isinstance(batch, torch.Tensor)

    dataset = StreamingDataset(input_dir=tmpdir, item_loader=TokensLoader(block_size=10))

    assert len(dataset) == 816
    dataset_iter = iter(dataset)
    assert len(dataset_iter) == 816

    dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
    assert len(dataloader) == 408


def test_create_oversized_chunk_single_item(tmp_path):
    cache = Cache(str(tmp_path), chunk_bytes=700)
    with pytest.warns(UserWarning, match="An item was larger than the target chunk size"):
@@ -12,49 +12,55 @@
# limitations under the License.

import os
from re import escape
from unittest import mock

import pytest
import torch
from lightning import seed_everything
from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming import Cache
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
from lightning.pytorch.demos.boring_classes import RandomDataset
from torch.utils.data import DataLoader


def test_streaming_dataset(tmpdir, monkeypatch):
    seed_everything(42)

    dataset = StreamingDataset(input_dir=tmpdir)
    with pytest.raises(ValueError, match="The provided dataset"):
        dataset = StreamingDataset(input_dir=tmpdir)
        iter(dataset)
    dataset = StreamingDataset(input_dir=tmpdir)
    with pytest.raises(ValueError, match="The provided dataset"):
        _ = dataset[0]

    dataset = RandomDataset(128, 64)
    dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12)
    for batch in dataloader:
        assert isinstance(batch, torch.Tensor)
    cache = Cache(tmpdir, chunk_size=10)
    for i in range(12):
        cache[i] = i
    cache.done()
    cache.merge()

    dataset = StreamingDataset(input_dir=tmpdir, item_loader=TokensLoader(block_size=10))
    dataset = StreamingDataset(input_dir=tmpdir)

    assert len(dataset) == 816
    assert len(dataset) == 12
    dataset_iter = iter(dataset)
    assert len(dataset_iter) == 816
    assert len(dataset_iter) == 12

    dataloader = DataLoader(dataset, num_workers=2, batch_size=1)
    assert len(dataloader) == 12
    dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
    assert len(dataloader) == 408
    assert len(dataloader) == 6


@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
@mock.patch("lightning.data.streaming.dataset.os.makedirs")
def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir):
    # Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call
    with pytest.raises(FileNotFoundError, match="`/cache/chunks/275876e34cf609db118f3d84b799a790` doesn't exist"):
        StreamingDataset("dummy")
    makedirs_mock.assert_called_once_with("/cache/chunks/275876e34cf609db118f3d84b799a790", exist_ok=True)
    dataset = StreamingDataset("dummy")
    expected = os.path.join("/cache", "chunks", "275876e34cf609db118f3d84b799a790", "0")
    with pytest.raises(FileNotFoundError, match=escape(f"`{expected}` doesn't exist")):
        iter(dataset)
    makedirs_mock.assert_called_once_with(expected, exist_ok=True)


@pytest.mark.parametrize("drop_last", [False, True])
@@ -69,8 +75,9 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
    cache.merge()

    dataset = StreamingDataset(input_dir=tmpdir, shuffle=False, drop_last=drop_last)

    assert isinstance(dataset.shuffle, NoShuffle)
    assert not dataset.shuffle
    _ = dataset[0]  # init shuffler
    assert isinstance(dataset.shuffler, NoShuffle)

    for i in range(101):
        assert dataset[i] == i
@@ -105,7 +112,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):

    assert len(process_2_2) == 50

    _, intervals_per_ranks = dataset.shuffle.get_chunks_and_intervals_per_ranks(
    _, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_ranks(
        dataset.distributed_env, dataset.current_epoch
    )
@@ -149,8 +156,9 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
    cache.merge()

    dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)

    assert isinstance(dataset.shuffle, FullShuffle)
    assert dataset.shuffle
    _ = dataset[0]
    assert isinstance(dataset.shuffler, FullShuffle)

    for i in range(1097):
        assert dataset[i] == i
@@ -164,7 +172,8 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
    assert len(process_1_1) == 548

    dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
    assert isinstance(dataset_2.shuffle, FullShuffle)
    iter(dataset_2)
    assert isinstance(dataset_2.shuffler, FullShuffle)
    dataset_2.distributed_env = _DistributedEnv(2, 1)
    assert len(dataset_2) == 548 + int(not drop_last)
    dataset_2_iter = iter(dataset_2)
@@ -187,8 +196,9 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
    cache.merge()

    dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)

    assert isinstance(dataset.shuffle, FullShuffle)
    assert dataset.shuffle
    _ = dataset[0]
    assert isinstance(dataset.shuffler, FullShuffle)

    for i in range(1222):
        assert dataset[i] == i
@@ -202,7 +212,8 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
    assert len(process_1_1) == 611

    dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last)
    assert isinstance(dataset_2.shuffle, FullShuffle)
    iter(dataset_2)
    assert isinstance(dataset_2.shuffler, FullShuffle)
    dataset_2.distributed_env = _DistributedEnv(2, 1)
    assert len(dataset_2) == 611
    dataset_2_iter = iter(dataset_2)
@@ -229,6 +240,9 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch):
    cache.merge()

    dataset = StreamingDataset(input_dir=remote_dir, shuffle=True)
    assert dataset.cache is None
    iter(dataset)
    assert dataset.cache is not None
    assert dataset.cache._reader._prepare_thread is None
    dataset.cache._reader._prepare_thread = True
    dataloader = DataLoader(dataset, num_workers=1)
@@ -240,11 +254,75 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch):
    assert len(batches) == 10


@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
@mock.patch("lightning.data.streaming.dataset.os.makedirs")
def test_try_create_cache_dir(makedirs, monkeypatch):
    cache_dir_1 = _try_create_cache_dir("")
    cache_dir_2 = _try_create_cache_dir("ssdf")
    assert cache_dir_1 != cache_dir_2
    assert cache_dir_1 == "/cache/chunks/d41d8cd98f00b204e9800998ecf8427e"
    assert len(makedirs._mock_mock_calls) == 2
def test_dataset_cache_recreation(tmpdir):
    """Test that we recreate the cache and other objects only when appropriate."""
    cache = Cache(tmpdir, chunk_size=10)
    for i in range(10):
        cache[i] = i
    cache.done()
    cache.merge()

    # repated `len()` calls
    dataset = StreamingDataset(input_dir=tmpdir)
    assert not dataset.cache
    assert not dataset.shuffler
    len(dataset)
    assert not dataset.cache
    shuffler = dataset.shuffler
    assert isinstance(shuffler, NoShuffle)
    len(dataset)
    assert dataset.shuffler is shuffler

    # repeated `iter()` calls
    dataset = StreamingDataset(input_dir=tmpdir)
    assert not dataset.cache
    assert not dataset.shuffler
    iter(dataset)
    cache = dataset.cache
    shuffler = dataset.shuffler
    assert isinstance(cache, Cache)
    assert isinstance(shuffler, NoShuffle)
    iter(dataset)
    assert isinstance(dataset.cache, Cache)
    assert isinstance(dataset.shuffler, NoShuffle)
    assert dataset.cache is not cache  # cache gets recreated
    assert dataset.shuffler is not shuffler  # shuffler gets recreated

    # repeated `getitem()` calls
    dataset = StreamingDataset(input_dir=tmpdir)
    assert not dataset.cache
    assert not dataset.shuffler
    _ = dataset[0]
    cache = dataset.cache
    shuffler = dataset.shuffler
    assert isinstance(cache, Cache)
    assert isinstance(shuffler, NoShuffle)
    _ = dataset[1]
    assert dataset.cache is cache  # cache gets reused
    assert dataset.shuffler is shuffler  # shuffler gets reused


def test_try_create_cache_dir():
    with mock.patch.dict(os.environ, {}, clear=True):
        assert _try_create_cache_dir("any") is None

    # the cache dir creating at /cache requires root privileges, so we need to mock `os.makedirs()`
    with (
        mock.patch.dict("os.environ", {"LIGHTNING_CLUSTER_ID": "abc", "LIGHTNING_CLOUD_PROJECT_ID": "123"}),
        mock.patch("lightning.data.streaming.dataset.os.makedirs") as makedirs_mock,
    ):
        cache_dir_1 = _try_create_cache_dir("")
        cache_dir_2 = _try_create_cache_dir("ssdf")
        assert cache_dir_1 != cache_dir_2
        assert cache_dir_1 == os.path.join("/cache", "chunks", "d41d8cd98f00b204e9800998ecf8427e", "0")
        assert len(makedirs_mock.mock_calls) == 2

        assert _try_create_cache_dir("dir", shard_rank=0) == os.path.join(
            "/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "0"
        )
        assert _try_create_cache_dir("dir", shard_rank=1) == os.path.join(
            "/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "1"
        )
        assert _try_create_cache_dir("dir", shard_rank=2) == os.path.join(
            "/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "2"
        )
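Taken together with the tests above, a typical multi-worker run now looks like this; the path below is a placeholder, and the dataset is assumed to have been optimized/written there beforehand (e.g. via the `Cache` as in the tests):

    from torch.utils.data import DataLoader
    from lightning.data.streaming.dataset import StreamingDataset

    dataset = StreamingDataset(input_dir="path/to/optimized/dataset")  # placeholder path
    # No cache exists yet: it is only created per worker once iteration starts.
    assert dataset.cache is None

    dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
    for batch in dataloader:
        ...  # each worker builds its own Cache; on the cloud, each gets its own shard-rank subdirectory

Outside the Lightning cloud environment (`LIGHTNING_CLUSTER_ID` / `LIGHTNING_CLOUD_PROJECT_ID` unset), `_try_create_cache_dir` returns `None` and the dataset keeps reading from the provided `input_dir` directly.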
@@ -67,6 +67,7 @@ def test_pil_serializer(mode):
    assert np.array_equal(np_data, np_dec_data)


@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_tensor_serializer():
    seed_everything(42)