From 8a5d3423a711875badceba8c0c40016591d58aa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Nov 2023 16:19:03 +0100 Subject: [PATCH] Cache directory per worker to avoid collisions (#18957) --- src/lightning/data/streaming/cache.py | 2 - src/lightning/data/streaming/dataset.py | 83 ++++++---- tests/tests_data/streaming/test_cache.py | 26 +--- tests/tests_data/streaming/test_dataset.py | 144 ++++++++++++++---- tests/tests_data/streaming/test_serializer.py | 1 + 5 files changed, 166 insertions(+), 90 deletions(-) diff --git a/src/lightning/data/streaming/cache.py b/src/lightning/data/streaming/cache.py index a28c20cf62..30100cc0e9 100644 --- a/src/lightning/data/streaming/cache.py +++ b/src/lightning/data/streaming/cache.py @@ -57,8 +57,6 @@ class Cache: Arguments: input_dir: The path to where the chunks will be or are stored. - name: The name of dataset in the cloud. - version: The version of the dataset in the cloud to use. By default, we will use the latest. compression: The name of the algorithm to reduce the size of the chunks. chunk_bytes: The maximum number of bytes within a chunk. chunk_size: The maximum number of items within a chunk. diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index c014ef346c..70a69ea29d 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import hashlib import os from typing import Any, List, Optional, Union @@ -18,7 +19,7 @@ from typing import Any, List, Optional, Union import numpy as np from torch.utils.data import IterableDataset -from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv +from lightning.data.datasets.env import Environment, _DistributedEnv, _WorkerEnv from lightning.data.streaming import Cache from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50 from lightning.data.streaming.item_loader import BaseItemLoader @@ -57,27 +58,14 @@ class StreamingDataset(IterableDataset): input_dir = _resolve_dir(input_dir) - # Override the provided input_path - cache_dir = _try_create_cache_dir(input_dir.path) - if cache_dir: - input_dir.path = cache_dir - - self.cache = Cache(input_dir=input_dir, item_loader=item_loader, chunk_bytes=1) - - self.cache._reader._try_load_config() - - if not self.cache.filled: - raise ValueError( - f"The provided dataset `{input_dir}` doesn't contain any {_INDEX_FILENAME} file." - " HINT: Did you successfully optimize a dataset to the provided `input_dir` ?" 
- ) - - self.distributed_env = _DistributedEnv.detect() - - self.shuffle: Shuffle = ( - FullShuffle(self.cache, seed, drop_last) if shuffle else NoShuffle(self.cache, seed, drop_last) - ) + self.input_dir = input_dir + self.item_loader = item_loader + self.shuffle: bool = shuffle self.drop_last = drop_last + self.seed = seed + + self.cache: Optional[Cache] = None + self.distributed_env = _DistributedEnv.detect() self.worker_env: Optional[_WorkerEnv] = None self.worker_chunks: List[int] = [] self.worker_intervals: List[List[int]] = [] @@ -86,23 +74,52 @@ class StreamingDataset(IterableDataset): self.index = 0 self.has_triggered_download = False self.min_items_per_replica: Optional[int] = None - self.seed = seed self.current_epoch = 0 self.random_state = None + self.shuffler: Optional[Shuffle] = None + + def _create_cache(self, worker_env: _WorkerEnv) -> Cache: + env = Environment(dist_env=self.distributed_env, worker_env=worker_env) + cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank) + cache_dir = copy.deepcopy(self.input_dir) + if cache_path: + cache_dir.path = cache_path + + cache = Cache(input_dir=cache_dir, item_loader=self.item_loader, chunk_bytes=1) + cache._reader._try_load_config() + + if not cache.filled: + raise ValueError( + f"The provided dataset `{self.input_dir}` doesn't contain any {_INDEX_FILENAME} file." + " HINT: Did you successfully optimize a dataset to the provided `input_dir`?" + ) + + return cache + + def _create_shuffler(self, cache: Cache) -> Shuffle: + return ( + FullShuffle(cache, self.seed, self.drop_last) + if self.shuffle + else NoShuffle(cache, self.seed, self.drop_last) + ) def __len__(self) -> int: - return self.shuffle.get_len(self.distributed_env, self.current_epoch) + if self.shuffler is None: + cache = self._create_cache(worker_env=_WorkerEnv.detect()) + self.shuffler = self._create_shuffler(cache) + return self.shuffler.get_len(self.distributed_env, self.current_epoch) def __iter__(self) -> "StreamingDataset": - chunks_per_replica, intervals_per_replica = self.shuffle.get_chunks_and_intervals_per_ranks( + self.worker_env = _WorkerEnv.detect() + self.cache = self._create_cache(worker_env=self.worker_env) + self.shuffler = self._create_shuffler(self.cache) + + chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks( self.distributed_env, self.current_epoch ) current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] - if self.worker_env is None: - self.worker_env = _WorkerEnv.detect() - self.worker_chunks = [] self.worker_intervals = [] @@ -119,6 +136,10 @@ class StreamingDataset(IterableDataset): return self def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: + if self.cache is None: + self.worker_env = _WorkerEnv.detect() + self.cache = self._create_cache(worker_env=self.worker_env) + self.shuffler = self._create_shuffler(self.cache) if isinstance(index, int): index = ChunkedIndex(index, self.cache._get_chunk_index_from_index(index)) return self.cache[index] @@ -137,7 +158,9 @@ class StreamingDataset(IterableDataset): interval = self.worker_intervals[self.chunk_index] current_indexes = np.arange(interval[0], interval[1]) - self.current_indexes = self.shuffle(current_indexes) + + assert self.shuffler is not None + self.current_indexes = self.shuffler(current_indexes) self.chunk_index += 1 # Get the first index @@ 
-158,10 +181,10 @@ class StreamingDataset(IterableDataset): return data -def _try_create_cache_dir(input_dir: str) -> Optional[str]: +def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]: if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ: return None hash_object = hashlib.md5(input_dir.encode()) - cache_dir = os.path.join(f"/cache/chunks/{hash_object.hexdigest()}") + cache_dir = os.path.join("/cache", "chunks", hash_object.hexdigest(), str(shard_rank)) os.makedirs(cache_dir, exist_ok=True) return cache_dir diff --git a/tests/tests_data/streaming/test_cache.py b/tests/tests_data/streaming/test_cache.py index 41317735be..4f7da16aa0 100644 --- a/tests/tests_data/streaming/test_cache.py +++ b/tests/tests_data/streaming/test_cache.py @@ -23,12 +23,11 @@ from lightning.data.datasets.env import _DistributedEnv from lightning.data.streaming import Cache from lightning.data.streaming.dataloader import StreamingDataLoader from lightning.data.streaming.dataset import StreamingDataset -from lightning.data.streaming.item_loader import TokensLoader from lightning.fabric import Fabric from lightning.pytorch.demos.boring_classes import RandomDataset from lightning_utilities.core.imports import RequirementCache from lightning_utilities.test.warning import no_warning_call -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") @@ -222,29 +221,6 @@ def test_cache_with_auto_wrapping(tmpdir): pass -def test_streaming_dataset(tmpdir, monkeypatch): - seed_everything(42) - - os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True) - - with pytest.raises(ValueError, match="The provided dataset"): - dataset = StreamingDataset(input_dir=tmpdir) - - dataset = RandomDataset(128, 64) - dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12) - for batch in dataloader: - assert isinstance(batch, torch.Tensor) - - dataset = StreamingDataset(input_dir=tmpdir, item_loader=TokensLoader(block_size=10)) - - assert len(dataset) == 816 - dataset_iter = iter(dataset) - assert len(dataset_iter) == 816 - - dataloader = DataLoader(dataset, num_workers=2, batch_size=2) - assert len(dataloader) == 408 - - def test_create_oversized_chunk_single_item(tmp_path): cache = Cache(str(tmp_path), chunk_bytes=700) with pytest.warns(UserWarning, match="An item was larger than the target chunk size"): diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index 15ed713351..f903f8b377 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -12,49 +12,55 @@ # limitations under the License. 
import os +from re import escape from unittest import mock import pytest -import torch from lightning import seed_everything from lightning.data.datasets.env import _DistributedEnv from lightning.data.streaming import Cache -from lightning.data.streaming.dataloader import StreamingDataLoader from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir -from lightning.data.streaming.item_loader import TokensLoader from lightning.data.streaming.shuffle import FullShuffle, NoShuffle -from lightning.pytorch.demos.boring_classes import RandomDataset from torch.utils.data import DataLoader def test_streaming_dataset(tmpdir, monkeypatch): seed_everything(42) + dataset = StreamingDataset(input_dir=tmpdir) with pytest.raises(ValueError, match="The provided dataset"): - dataset = StreamingDataset(input_dir=tmpdir) + iter(dataset) + dataset = StreamingDataset(input_dir=tmpdir) + with pytest.raises(ValueError, match="The provided dataset"): + _ = dataset[0] - dataset = RandomDataset(128, 64) - dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12) - for batch in dataloader: - assert isinstance(batch, torch.Tensor) + cache = Cache(tmpdir, chunk_size=10) + for i in range(12): + cache[i] = i + cache.done() + cache.merge() - dataset = StreamingDataset(input_dir=tmpdir, item_loader=TokensLoader(block_size=10)) + dataset = StreamingDataset(input_dir=tmpdir) - assert len(dataset) == 816 + assert len(dataset) == 12 dataset_iter = iter(dataset) - assert len(dataset_iter) == 816 + assert len(dataset_iter) == 12 + dataloader = DataLoader(dataset, num_workers=2, batch_size=1) + assert len(dataloader) == 12 dataloader = DataLoader(dataset, num_workers=2, batch_size=2) - assert len(dataloader) == 408 + assert len(dataloader) == 6 @mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"}) @mock.patch("lightning.data.streaming.dataset.os.makedirs") def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir): # Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call - with pytest.raises(FileNotFoundError, match="`/cache/chunks/275876e34cf609db118f3d84b799a790` doesn't exist"): - StreamingDataset("dummy") - makedirs_mock.assert_called_once_with("/cache/chunks/275876e34cf609db118f3d84b799a790", exist_ok=True) + dataset = StreamingDataset("dummy") + expected = os.path.join("/cache", "chunks", "275876e34cf609db118f3d84b799a790", "0") + with pytest.raises(FileNotFoundError, match=escape(f"`{expected}` doesn't exist")): + iter(dataset) + makedirs_mock.assert_called_once_with(expected, exist_ok=True) @pytest.mark.parametrize("drop_last", [False, True]) @@ -69,8 +75,9 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir): cache.merge() dataset = StreamingDataset(input_dir=tmpdir, shuffle=False, drop_last=drop_last) - - assert isinstance(dataset.shuffle, NoShuffle) + assert not dataset.shuffle + _ = dataset[0] # init shuffler + assert isinstance(dataset.shuffler, NoShuffle) for i in range(101): assert dataset[i] == i @@ -105,7 +112,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir): assert len(process_2_2) == 50 - _, intervals_per_ranks = dataset.shuffle.get_chunks_and_intervals_per_ranks( + _, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_ranks( dataset.distributed_env, dataset.current_epoch ) @@ -149,8 +156,9 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir): cache.merge() dataset = 
StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last) - - assert isinstance(dataset.shuffle, FullShuffle) + assert dataset.shuffle + _ = dataset[0] + assert isinstance(dataset.shuffler, FullShuffle) for i in range(1097): assert dataset[i] == i @@ -164,7 +172,8 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir): assert len(process_1_1) == 548 dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last) - assert isinstance(dataset_2.shuffle, FullShuffle) + iter(dataset_2) + assert isinstance(dataset_2.shuffler, FullShuffle) dataset_2.distributed_env = _DistributedEnv(2, 1) assert len(dataset_2) == 548 + int(not drop_last) dataset_2_iter = iter(dataset_2) @@ -187,8 +196,9 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir): cache.merge() dataset = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last) - - assert isinstance(dataset.shuffle, FullShuffle) + assert dataset.shuffle + _ = dataset[0] + assert isinstance(dataset.shuffler, FullShuffle) for i in range(1222): assert dataset[i] == i @@ -202,7 +212,8 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir): assert len(process_1_1) == 611 dataset_2 = StreamingDataset(input_dir=tmpdir, shuffle=True, drop_last=drop_last) - assert isinstance(dataset_2.shuffle, FullShuffle) + iter(dataset_2) + assert isinstance(dataset_2.shuffler, FullShuffle) dataset_2.distributed_env = _DistributedEnv(2, 1) assert len(dataset_2) == 611 dataset_2_iter = iter(dataset_2) @@ -229,6 +240,9 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch): cache.merge() dataset = StreamingDataset(input_dir=remote_dir, shuffle=True) + assert dataset.cache is None + iter(dataset) + assert dataset.cache is not None assert dataset.cache._reader._prepare_thread is None dataset.cache._reader._prepare_thread = True dataloader = DataLoader(dataset, num_workers=1) @@ -240,11 +254,75 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch): assert len(batches) == 10 -@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"}) -@mock.patch("lightning.data.streaming.dataset.os.makedirs") -def test_try_create_cache_dir(makedirs, monkeypatch): - cache_dir_1 = _try_create_cache_dir("") - cache_dir_2 = _try_create_cache_dir("ssdf") - assert cache_dir_1 != cache_dir_2 - assert cache_dir_1 == "/cache/chunks/d41d8cd98f00b204e9800998ecf8427e" - assert len(makedirs._mock_mock_calls) == 2 +def test_dataset_cache_recreation(tmpdir): + """Test that we recreate the cache and other objects only when appropriate.""" + cache = Cache(tmpdir, chunk_size=10) + for i in range(10): + cache[i] = i + cache.done() + cache.merge() + + # repated `len()` calls + dataset = StreamingDataset(input_dir=tmpdir) + assert not dataset.cache + assert not dataset.shuffler + len(dataset) + assert not dataset.cache + shuffler = dataset.shuffler + assert isinstance(shuffler, NoShuffle) + len(dataset) + assert dataset.shuffler is shuffler + + # repeated `iter()` calls + dataset = StreamingDataset(input_dir=tmpdir) + assert not dataset.cache + assert not dataset.shuffler + iter(dataset) + cache = dataset.cache + shuffler = dataset.shuffler + assert isinstance(cache, Cache) + assert isinstance(shuffler, NoShuffle) + iter(dataset) + assert isinstance(dataset.cache, Cache) + assert isinstance(dataset.shuffler, NoShuffle) + assert dataset.cache is not cache # cache gets recreated + assert dataset.shuffler is not shuffler # shuffler gets recreated + + # 
repeated `getitem()` calls + dataset = StreamingDataset(input_dir=tmpdir) + assert not dataset.cache + assert not dataset.shuffler + _ = dataset[0] + cache = dataset.cache + shuffler = dataset.shuffler + assert isinstance(cache, Cache) + assert isinstance(shuffler, NoShuffle) + _ = dataset[1] + assert dataset.cache is cache # cache gets reused + assert dataset.shuffler is shuffler # shuffler gets reused + + +def test_try_create_cache_dir(): + with mock.patch.dict(os.environ, {}, clear=True): + assert _try_create_cache_dir("any") is None + + # the cache dir creating at /cache requires root privileges, so we need to mock `os.makedirs()` + with ( + mock.patch.dict("os.environ", {"LIGHTNING_CLUSTER_ID": "abc", "LIGHTNING_CLOUD_PROJECT_ID": "123"}), + mock.patch("lightning.data.streaming.dataset.os.makedirs") as makedirs_mock, + ): + cache_dir_1 = _try_create_cache_dir("") + cache_dir_2 = _try_create_cache_dir("ssdf") + assert cache_dir_1 != cache_dir_2 + assert cache_dir_1 == os.path.join("/cache", "chunks", "d41d8cd98f00b204e9800998ecf8427e", "0") + assert len(makedirs_mock.mock_calls) == 2 + + assert _try_create_cache_dir("dir", shard_rank=0) == os.path.join( + "/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "0" + ) + assert _try_create_cache_dir("dir", shard_rank=1) == os.path.join( + "/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "1" + ) + assert _try_create_cache_dir("dir", shard_rank=2) == os.path.join( + "/cache", "chunks", "736007832d2167baaae763fd3a3f3cf1", "2" + ) diff --git a/tests/tests_data/streaming/test_serializer.py b/tests/tests_data/streaming/test_serializer.py index 194f208a3a..9b0ef76957 100644 --- a/tests/tests_data/streaming/test_serializer.py +++ b/tests/tests_data/streaming/test_serializer.py @@ -67,6 +67,7 @@ def test_pil_serializer(mode): assert np.array_equal(np_data, np_dec_data) +@pytest.mark.flaky(reruns=3) @pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows") def test_tensor_serializer(): seed_everything(42)
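
For context, here is a minimal, self-contained sketch of the collision-avoidance scheme this patch introduces: the local cache path is derived from an MD5 hash of the input directory plus a per-shard sub-directory, so each (process rank, dataloader worker) pair writes its chunks into its own folder. It mirrors `_try_create_cache_dir` from the diff above, but the `root` parameter and the `shard_rank()` helper are illustrative additions only; in the patch the root is hard-coded to `/cache` and the shard rank comes from `Environment.shard_rank`, whose exact formula is not part of this diff.

import hashlib
import os
from typing import Optional


def try_create_cache_dir(input_dir: str, shard_rank: int = 0, root: str = "/cache") -> Optional[str]:
    # Same gating as in the patch: only create a local cache when running on the Lightning cloud.
    if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
        return None
    digest = hashlib.md5(input_dir.encode()).hexdigest()
    # One leaf directory per shard rank, so two workers never write chunks into the same folder.
    cache_dir = os.path.join(root, "chunks", digest, str(shard_rank))
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def shard_rank(global_rank: int, num_workers: int, worker_id: int) -> int:
    # Assumed composition of the distributed rank and the dataloader worker id (hypothetical helper,
    # standing in for `Environment.shard_rank`, which lives outside this diff).
    return global_rank * num_workers + worker_id


if __name__ == "__main__":
    os.environ.setdefault("LIGHTNING_CLUSTER_ID", "abc")
    os.environ.setdefault("LIGHTNING_CLOUD_PROJECT_ID", "123")
    # Two workers of the same process resolve the same dataset to different cache folders.
    print(try_create_cache_dir("my-dataset-dir", shard_rank(0, 2, 0), root="./cache"))
    print(try_create_cache_dir("my-dataset-dir", shard_rank(0, 2, 1), root="./cache"))

The `__len__` / `__iter__` / `__getitem__` changes in the diff follow from the same idea: the `Cache` and `Shuffle` objects are now created lazily inside the worker process, where `_WorkerEnv.detect()` is meaningful, instead of eagerly in `__init__`. That is why the updated tests assert that `dataset.cache` stays `None` until the first iteration or index access.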