Add fault tolerance for the StreamingDataset 1/n (#19049)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
parent bc1658039f
commit 1073276a58
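Note: the change set below gives `StreamingDataset` worker-level progress checkpoints plus a `state_dict`/`load_state_dict` pair to resume from them. A minimal resume sketch built on the API added in this commit (the input directory and hyper-parameters are illustrative):

    from torch.utils.data import DataLoader

    from lightning.data.streaming.dataset import StreamingDataset
    from lightning.data.streaming.item_loader import TokensLoader

    # "optimized_dir" stands in for a directory previously prepared for streaming.
    block_size = 20
    dataset = StreamingDataset(input_dir="optimized_dir", item_loader=TokensLoader(block_size), shuffle=True)
    dataloader = DataLoader(dataset, num_workers=1, batch_size=2)
    batch = next(iter(dataloader))

    # Snapshot the per-worker progress from the main process ...
    state = dataset.state_dict()

    # ... and replay it into a fresh dataset after a crash or restart.
    restored = StreamingDataset(input_dir="optimized_dir", item_loader=TokensLoader(block_size), shuffle=True)
    restored.load_state_dict(state)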
@@ -16,4 +16,11 @@ from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.item_loader import TokensLoader
 
-__all__ = ["Cache", "DataProcessor", "StreamingDataset", "DataTransformRecipe", "DataChunkRecipe", "TokensLoader"]
+__all__ = [
+    "Cache",
+    "DataProcessor",
+    "StreamingDataset",
+    "DataTransformRecipe",
+    "DataChunkRecipe",
+    "TokensLoader",
+]
@@ -94,6 +94,10 @@ class Cache:
         self._is_done = False
         self._distributed_env = _DistributedEnv.detect()
 
+    @property
+    def rank(self) -> int:
+        return self._reader.rank
+
     @property
     def filled(self) -> bool:
         """Returns whether the caching phase is done."""
@@ -102,6 +106,20 @@ class Cache:
         self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
         return self._is_done
 
+    @property
+    def checkpoint_dir(self) -> str:
+        checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir, exist_ok=True)
+        return checkpoint_dir
+
+    @property
+    def checkpoint_rank_dir(self) -> str:
+        checkpoint_rank_dir = os.path.join(self.checkpoint_dir, str(self.rank))
+        if not os.path.exists(checkpoint_rank_dir):
+            os.makedirs(checkpoint_rank_dir, exist_ok=True)
+        return checkpoint_rank_dir
+
     def __setitem__(self, index: int, data: Any) -> None:
         """Store an item in the writer."""
         self._writer[index] = data
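Note: the two properties above lazily create a `checkpoints/` folder inside the cache directory, with one subdirectory per rank so concurrent workers never write to the same location. A small sketch of how they compose (paths illustrative, and `checkpoint_rank_dir` assumes the reader's rank is already set up):

    from lightning.data.streaming import Cache
    from lightning.data.streaming.item_loader import TokensLoader

    cache = Cache(input_dir="cache_dir", chunk_size=40, item_loader=TokensLoader(20))
    cache.checkpoint_dir       # -> "cache_dir/checkpoints", created on first access
    cache.checkpoint_rank_dir  # -> "cache_dir/checkpoints/0" for the worker with rank 0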
@@ -51,3 +51,5 @@ _TORCH_DTYPES_MAPPING = {
     18: torch.long,
     19: torch.bool,
 }
+
+_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
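Note: `_TIME_FORMAT` gives every checkpoint file a zero-padded, microsecond-resolution timestamp, so file names can be parsed back into `datetime` objects and ordered chronologically. A round-trip sketch:

    from datetime import datetime

    _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
    now = datetime.now().strftime(_TIME_FORMAT)  # e.g. "2023-12-20_14-05-37.123456Z"
    name = f"checkpoint-{now}.json"
    # This is the parsing later done by `_string_to_datetime` to pick the
    # latest checkpoint of each worker.
    parsed = datetime.strptime(name.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)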
@@ -12,20 +12,34 @@
 # limitations under the License.
 
 import hashlib
+import json
 import os
+import shutil
+import sys
+import tempfile
+from copy import deepcopy
 from dataclasses import dataclass
+from datetime import datetime
+from time import time
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
+import torch
 from torch.utils.data import IterableDataset
 
 from lightning.data.streaming import Cache
-from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
+from lightning.data.streaming.constants import (
+    _DEFAULT_CACHE_DIR,
+    _INDEX_FILENAME,
+    _LIGHTNING_CLOUD_LATEST,
+    _TIME_FORMAT,
+)
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.serializers import Serializer
 from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
 from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
+from lightning.fabric.utilities.distributed import group as _group
 
 if _LIGHTNING_CLOUD_LATEST:
     from lightning_cloud.resolver import Dir, _resolve_dir
@@ -42,6 +56,7 @@ class StreamingDataset(IterableDataset):
         drop_last: bool = False,
         seed: int = 42,
         serializers: Optional[Dict[str, Serializer]] = None,
+        checkpoint_interval: int = 60 * 5,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
 
@@ -53,6 +68,7 @@ class StreamingDataset(IterableDataset):
                 all processes/workers return the same amount of data.
             seed: Random seed for shuffling.
             serializers: The serializers used to serialize and deserialize the chunks.
+            checkpoint_interval: Interval in seconds at which the workers are going to store their own progress.
 
         """
         super().__init__()
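Note: `checkpoint_interval` only throttles the time-based checkpoints; the chunk-boundary checkpoints happen regardless. A construction sketch (directory name illustrative):

    # Checkpoint progress at most every 5 minutes on top of the per-chunk checkpoints.
    dataset = StreamingDataset(input_dir="optimized_dir", checkpoint_interval=60 * 5)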
@@ -77,6 +93,7 @@ class StreamingDataset(IterableDataset):
         self.worker_intervals: List[List[int]] = []
         self.current_indexes: List[int] = []
         self.chunk_index = 0
+        self.global_index = 0
         self.index = 0
         self.has_triggered_download = False
         self.min_items_per_replica: Optional[int] = None
@@ -84,6 +101,8 @@ class StreamingDataset(IterableDataset):
         self.random_state = None
         self.shuffler: Optional[Shuffle] = None
         self.serializers = serializers
+        self.checkpoint_interval = checkpoint_interval
+        self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None
 
     def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
         env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
@@ -109,11 +128,10 @@ class StreamingDataset(IterableDataset):
         return cache
 
     def _create_shuffler(self, cache: Cache) -> Shuffle:
-        return (
-            FullShuffle(cache, self.seed, self.drop_last)
-            if self.shuffle
-            else NoShuffle(cache, self.seed, self.drop_last)
-        )
+        seed = self.seed
+        if self._state_dict is not None:
+            seed = self._state_dict[str(cache.rank)]["seed"]
+        return FullShuffle(cache, seed, self.drop_last) if self.shuffle else NoShuffle(cache, seed, self.drop_last)
 
     def __len__(self) -> int:
         if self.shuffler is None:
@@ -126,6 +144,17 @@ class StreamingDataset(IterableDataset):
         self.cache = self._create_cache(worker_env=self.worker_env)
         self.shuffler = self._create_shuffler(self.cache)
 
+        # Handle restart
+        if self._state_dict:
+            self._validate_state_dict()
+            state = self._state_dict[str(self.cache.rank)]
+
+            # reload indexes
+            self.chunk_index = state["chunk_index"]
+            self.global_index = state["global_index"]
+            self.index = state["index"]
+            self.current_epoch = state["current_epoch"]
+
         chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
             self.distributed_env, self.current_epoch
         )
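Note: the restart branch above reads one entry per shard from the state produced by `checkpoint` below. A sketch of the payload shape, keyed by cache rank and showing a subset of the recorded keys (values illustrative):

    state_dict = {
        "0": {
            "rank": 0,
            "current_epoch": 1,
            "chunk_index": 2,    # next chunk to replay
            "global_index": 4,   # items served across the epoch
            "index": 0,          # items served inside the current chunk
            "seed": 42,
            "shuffle": True,
            "drop_last": False,
            "num_workers": 1,
            "world_size": 1,
        },
    }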
@@ -141,10 +170,26 @@ class StreamingDataset(IterableDataset):
             self.worker_chunks.append(chunk_index)
             self.worker_intervals.append(chunk_interval)
 
-        self.current_indexes = []
-        self.chunk_index = 0
-        self.index = 0
+        # Handle restart
+        if self._state_dict:
+            state = self._state_dict[str(self.cache.rank)]
+
+            # re-generate indexes
+            interval = self.worker_intervals[self.chunk_index]
+            current_indexes = np.arange(interval[0], interval[1])
+            current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
+            self.current_indexes = current_indexes[state["index"] :]
+
+            # Bump the chunk_index
+            self.chunk_index += 1
+        else:
+            self.current_indexes = []
+            self.chunk_index = 0
+            self.global_index = 0
+            self.index = 0
 
         self.has_triggered_download = False
+        self.last_time = time()
 
         return self
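Note: when resuming mid-epoch, the worker regenerates the permutation of the interrupted chunk and skips the items already served. A standalone sketch of that slicing (interval and counts illustrative):

    import numpy as np

    interval = (40, 80)  # the interrupted chunk's index range
    indexes = np.arange(interval[0], interval[1])
    # seed + current_epoch + chunk_index, as FullShuffle derives it below
    shuffled = np.random.RandomState(seed=42 + 1 + 2).permutation(indexes)
    remaining = shuffled[3:]  # state["index"] == 3 items were already consumed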
@@ -159,7 +204,7 @@ class StreamingDataset(IterableDataset):
 
     def __next__(self) -> Any:
         # Prevent creating more batches on a given process
-        if self.index >= len(self):
+        if self.global_index >= len(self):
             self.current_epoch += 1
             raise StopIteration
 
@@ -169,14 +214,19 @@ class StreamingDataset(IterableDataset):
                 self.current_epoch += 1
                 raise StopIteration
 
+            # reset index
+            self.index = 0
+
+            # Checkpoint when reaching a new chunk
+            self.checkpoint(self.chunk_index)
+
             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])
 
             assert self.shuffler is not None
-            self.current_indexes = self.shuffler(current_indexes)
-            self.chunk_index += 1
-
-            last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1
+            self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
+
+            self.chunk_index += 1
 
         # Get the first index
         index = self.current_indexes.pop(0)
@@ -188,15 +238,165 @@ class StreamingDataset(IterableDataset):
                 chunk_index=self.worker_chunks[self.chunk_index - 1],
                 # We provide the chunk indexes only on the first call
                 chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
-                last_index=last_index,
+                last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
             )
         )
 
         self.has_triggered_download = True
+        self.global_index += 1
         self.index += 1
 
+        # Checkpoint based on time
+        if (time() - self.last_time) > self.checkpoint_interval:
+            self.checkpoint(self.chunk_index - 1)
+
         return data
 
+    def checkpoint(self, chunk_index: int) -> None:
+        # Checkpointing isn't supported on Windows
+        if sys.platform == "win32":
+            return
+
+        assert self.cache
+        assert self.worker_env
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmp_checkpoint_path = os.path.join(tmpdir, "checkpoint.json")
+            with open(tmp_checkpoint_path, "w") as f:
+                # 1. Write the state to a tempfile
+                json.dump(
+                    {
+                        "rank": self.cache._reader.rank,
+                        "current_epoch": self.current_epoch,
+                        "input_dir_path": self.input_dir.path,
+                        "input_dir_url": self.input_dir.url,
+                        "item_loader": self.item_loader.state_dict() if self.item_loader else None,
+                        "drop_last": self.drop_last,
+                        "seed": self.seed,
+                        "checkpoint_interval": self.checkpoint_interval,
+                        "chunk_index": chunk_index,
+                        "global_index": self.global_index,
+                        "index": self.index,
+                        "world_size": self.distributed_env.world_size,
+                        "num_workers": self.worker_env.world_size,
+                        "shuffle": self.shuffle,
+                    },
+                    f,
+                )
+
+            # 2. Build a time-stamped target name to avoid a corrupted read from the main thread
+            now = datetime.now().strftime(_TIME_FORMAT)
+            checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
+
+            # 3. Move the file to its target position
+            shutil.move(tmp_checkpoint_path, checkpoint_path)
+
+        self.last_time = time()
+
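Note: writing into a temporary directory first and moving the finished file afterwards is the classic way to keep readers from observing a half-written JSON document (the move is an atomic rename whenever source and target share a filesystem). The same idea, condensed to its essentials:

    import json
    import os
    import shutil
    import tempfile

    def atomic_write_json(payload: dict, target_path: str) -> None:
        # Readers either see the previous checkpoint set or the complete new
        # file, never a partially flushed one.
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = os.path.join(tmpdir, "checkpoint.json")
            with open(tmp_path, "w") as f:
                json.dump(payload, f)
            shutil.move(tmp_path, target_path)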
+    def state_dict(self) -> Dict[str, Any]:
+        if self.cache is None:
+            self.worker_env = _WorkerEnv.detect()
+            self.cache = self._create_cache(worker_env=self.worker_env)
+
+        state_dict: Dict[str, Any] = {}
+        worker_env = _WorkerEnv.detect()
+        if worker_env.world_size == 1:
+            # 1. Check whether the checkpoint_dir exists
+            if not os.path.exists(self.cache.checkpoint_dir):
+                return state_dict
+
+            # 2. Iterate through the workers and read the latest checkpoint
+            for worker_idx in os.listdir(self.cache.checkpoint_dir):
+                checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
+                checkpoints = sorted(checkpoints, key=_string_to_datetime)
+
+                # Load the latest checkpoint for this worker
+                checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
+                with open(checkpoint_path) as f:
+                    state_dict[worker_idx] = json.load(f)
+
+            _state_dict = deepcopy(state_dict)
+
+            if self.distributed_env.world_size > 1:
+                # TODO: Move this to fabric.
+                num_devices = torch.cuda.device_count() or 1
+                node_ranks = []
+                for index in range(self.distributed_env.world_size):
+                    node_rank = index // num_devices
+                    if node_rank in node_ranks:
+                        continue
+                    state = {}
+                    obj = [_state_dict]
+                    torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
+                    state = obj[0]
+                    state_dict.update(**state)
+                    node_ranks.append(node_rank)
+        else:
+            raise NotImplementedError("The `state_dict` should be called on the main thread.")
+        return state_dict
+
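Note: in the multi-node branch each process only sees the checkpoints on its own node-local disk, so the first rank of every node broadcasts its share and all processes merge the union. With 16 processes and 4 devices per node the broadcast sources are the node leaders, which is exactly what the distributed test at the bottom of this commit asserts:

    world_size, num_devices = 16, 4
    node_leaders = [rank for rank in range(world_size) if rank % num_devices == 0]
    assert node_leaders == [0, 4, 8, 12]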
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        if state_dict:
+            # the state is restored within the workers
+            self._state_dict = state_dict
+
+    def _validate_state_dict(self) -> None:
+        assert self._state_dict
+        assert self.worker_env
+        assert self.cache
+
+        env = Environment(dist_env=self.distributed_env, worker_env=self.worker_env)
+
+        if env.num_shards != len(self._state_dict):
+            raise ValueError(
+                "The provided `state` size doesn't match the number of shards (num_workers x world_size). "
+                f"Found `{env.num_shards}` instead of `{len(self._state_dict)}`."
+            )
+
+        state: Dict[str, Any] = self._state_dict[str(self.cache.rank)]
+
+        if state["shuffle"] != self.shuffle:
+            raise ValueError(
+                "The provided `shuffle` state doesn't match the current one. "
+                f"Found `{self.shuffle}` instead of `{state['shuffle']}`."
+            )
+
+        if state["num_workers"] != self.worker_env.world_size:
+            raise ValueError(
+                "The provided `num_workers` state doesn't match the current one. "
+                f"Found `{self.worker_env.world_size}` instead of `{state['num_workers']}`."
+            )
+
+        if state["input_dir_path"] != self.input_dir.path:
+            raise ValueError(
+                "The provided `input_dir` path state doesn't match the current one. "
+                f"Found `{self.input_dir.path}` instead of `{state['input_dir_path']}`."
+            )
+
+        if state["input_dir_url"] != self.input_dir.url:
+            raise ValueError(
+                "The provided `input_dir` URL state doesn't match the current one. "
+                f"Found `{self.input_dir.url}` instead of `{state['input_dir_url']}`."
+            )
+
+        if state["seed"] != self.seed:
+            raise ValueError(
+                "The provided `seed` state doesn't match the current one. "
+                f"Found `{self.seed}` instead of `{state['seed']}`."
+            )
+
+        if self.item_loader and state["item_loader"] != self.item_loader.state_dict():
+            raise ValueError(
+                "The provided `item_loader` state doesn't match the current one. "
+                f"Found `{self.item_loader.state_dict()}` instead of `{state['item_loader']}`."
+            )
+
+        if state["drop_last"] != self.drop_last:
+            raise ValueError(
+                "The provided `drop_last` state doesn't match the current one. "
+                f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
+            )
+
 
 def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
     hash_object = hashlib.md5(input_dir.encode())
@@ -209,6 +409,10 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
     return cache_dir
 
 
+def _string_to_datetime(item: str) -> datetime:
+    return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)
+
+
 @dataclass
 class RemoteDir:
     """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""
@@ -37,6 +37,9 @@ class BaseItemLoader(ABC):
         self._chunks = chunks
         self._serializers = serializers
 
+    def state_dict(self) -> Dict:
+        return {}
+
     @abstractmethod
     def generate_intervals(self) -> List[Tuple[int, int]]:
         """Returns a list of tuples describing the index intervals of the chunks."""
@@ -115,6 +118,11 @@ class TokensLoader(BaseItemLoader):
         self._dtype: Optional[torch.dtype] = None
         self._chunk_filepaths: Dict[str, bool] = {}
 
+    def state_dict(self) -> Dict:
+        return {
+            "block_size": self._block_size,
+        }
+
     def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer]) -> None:
         super().setup(config, chunks, serializers)
         self._dtype = _TORCH_DTYPES_MAPPING[int(config["data_format"][0].split(":")[1])]
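Note: `state_dict` on the item loaders feeds the restore-time validation: `TokensLoader` records its block size while the base class contributes nothing. A usage sketch:

    loader = TokensLoader(20)
    assert loader.state_dict() == {"block_size": 20}
    # On restore, StreamingDataset compares this dict against the checkpointed
    # one and refuses to resume when the block size changed in between.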
@@ -28,7 +28,6 @@ class Shuffle(ABC):
         self.cache = cache
         self.seed = seed
         self.drop_last = drop_last
-        self.random_state = None
 
     @lru_cache(maxsize=10)
     def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
@@ -48,7 +47,7 @@
         pass
 
     @abstractmethod
-    def __call__(self, array: np.ndarray) -> List[int]:
+    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
         pass
 
@@ -68,7 +67,7 @@ class NoShuffle(Shuffle):
 
         return chunks_per_ranks, intervals_per_ranks
 
-    def __call__(self, array: np.ndarray) -> List[int]:
+    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
         return array.tolist()
 
@@ -92,14 +91,12 @@ class FullShuffle(Shuffle):
 
     @lru_cache(maxsize=10)
     def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
-        self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
-
         # 1. Get the intervals
         chunk_intervals = self.cache.get_chunk_intervals()
 
         # 2. Shuffle them
         indexes = range(len(chunk_intervals))
-        shuffled_indexes = self.random_state.permutation(indexes)
+        shuffled_indexes = np.random.RandomState(seed=self.seed + current_epoch).permutation(indexes)
         shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
 
         # 3. Compute the items budget of each rank
@@ -147,6 +144,5 @@ class FullShuffle(Shuffle):
 
         return chunks_per_ranks, intervals_per_ranks
 
-    def __call__(self, array: np.ndarray) -> List[int]:
-        assert self.random_state
-        return self.random_state.permutation(array).tolist()
+    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
+        return np.random.RandomState(seed=self.seed + current_epoch + chunk_index).permutation(array).tolist()
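Note: the shufflers are now stateless: every permutation is derived on the fly from `(seed, current_epoch, chunk_index)` rather than from a `RandomState` carried across calls, so a shuffler rebuilt after a restart reproduces the exact same ordering. This is also why the expected index sequences in the updated tests below change. A determinism sketch:

    import numpy as np

    seed, current_epoch, chunk_index = 42, 1, 3
    array = np.arange(40, 80)
    a = np.random.RandomState(seed=seed + current_epoch + chunk_index).permutation(array)
    b = np.random.RandomState(seed=seed + current_epoch + chunk_index).permutation(array)
    assert (a == b).all()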
@@ -161,7 +161,7 @@ def test_remove_target(tmpdir):
 
 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
-def test_download_data_target(tmpdir):
+@mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
+def test_download_data_target(wait_for_disk_usage_higher_than_threshold_mock, tmpdir):
     input_dir = os.path.join(tmpdir, "input_dir")
     os.makedirs(input_dir, exist_ok=True)
 
@@ -194,6 +194,8 @@ def test_download_data_target(tmpdir):
 
     assert os.listdir(cache_dir) == ["a.txt"]
 
+    wait_for_disk_usage_higher_than_threshold_mock.assert_called()
+
 
 def test_wait_for_disk_usage_higher_than_threshold():
     disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
@@ -11,8 +11,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
 import sys
+from datetime import datetime
+from time import sleep
 from unittest import mock
 
 import numpy as np
@@ -20,6 +23,7 @@ import pytest
 import torch
 from lightning import seed_everything
 from lightning.data.streaming import Cache, functions
+from lightning.data.streaming.constants import _TIME_FORMAT
 from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
 from lightning.data.streaming.item_loader import TokensLoader
 from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
@@ -160,7 +164,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
     dataset_iter = iter(dataset)
     assert len(dataset_iter) == 548
     process_1_1 = list(dataset_iter)
-    assert process_1_1[:10] == [785, 788, 782, 783, 789, 787, 786, 781, 784, 780]
+    assert process_1_1[:10] == [788, 781, 785, 780, 787, 782, 789, 784, 783, 786]
     assert len(process_1_1) == 548
 
     dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
@@ -171,7 +175,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
     dataset_2_iter = iter(dataset_2)
     assert len(dataset_2_iter) == 548 + int(not drop_last)
     process_2_1 = list(dataset_2_iter)
-    assert process_2_1[:10] == [939, 938, 252, 259, 257, 255, 258, 253, 250, 251]
+    assert process_2_1[:10] == [939, 938, 253, 259, 256, 258, 252, 255, 251, 257]
     assert len(process_2_1) == 548 + int(not drop_last)
     assert len([i for i in process_1_1 if i in process_2_1]) == 0
 
@@ -200,7 +204,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
     dataset_iter = iter(dataset)
     assert len(dataset_iter) == 611
     process_1_1 = list(dataset_iter)
-    assert process_1_1[:10] == [185, 184, 182, 189, 187, 181, 183, 180, 186, 188]
+    assert process_1_1[:10] == [188, 181, 185, 180, 187, 182, 189, 184, 183, 186]
     assert len(process_1_1) == 611
 
     dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
@@ -211,9 +215,8 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
     dataset_2_iter = iter(dataset_2)
     assert len(dataset_2_iter) == 611
     process_2_1 = list(dataset_2_iter)
-    assert process_2_1[:10] == [813, 815, 816, 812, 818, 811, 817, 814, 819, 277]
+    assert process_2_1[:10] == [818, 812, 816, 811, 819, 813, 815, 814, 817, 273]
     assert len(process_2_1) == 611
-
     assert len([i for i in process_1_1 if i in process_2_1]) == 0
 
@@ -527,3 +530,209 @@ def test_s3_streaming_dataset():
     dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet")
     assert dataset.input_dir.url == "s3://pl-flash-data/optimized_tiny_imagenet"
     assert dataset.input_dir.path is None
 
 
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
+def test_resumable_dataset_single_worker(tmpdir):
+    seed_everything(42)
+
+    block_size = 20
+    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+
+    counter = 0
+    for i in range(100):
+        text_ids = torch.arange(counter, counter + 20).to(torch.int)
+        cache[i] = text_ids
+        counter += 20
+
+    cache.done()
+    cache.merge()
+
+    assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 50
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
+
+    dataset.current_epoch = 1
+
+    assert dataset.state_dict() == {}
+
+    dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
+
+    dataloader_iter = iter(dataloader)
+
+    _ = next(dataloader_iter)
+    state_dict_0 = dataset.state_dict()
+
+    sleep(0.1)
+
+    assert state_dict_0["0"]["chunk_index"] == 0
+    assert state_dict_0["0"]["index"] == 0
+
+    checkpoint_dir = os.path.join(tmpdir, "checkpoints")
+    assert os.listdir(checkpoint_dir) == ["0"]
+    _ = next(dataloader_iter)
+
+    sleep(0.1)
+
+    state_dict_1 = dataset.state_dict()
+    assert state_dict_1["0"]["chunk_index"] == 2
+    assert state_dict_1["0"]["index"] == 0
+
+    batch_2 = next(dataloader_iter)
+
+    sleep(0.1)
+
+    state_dict_2 = dataset.state_dict()
+    assert state_dict_2["0"]["chunk_index"] == 3
+    assert state_dict_2["0"]["index"] == 0
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
+    dataset.load_state_dict(state_dict_1)
+    dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
+
+    dataloader_iter = iter(dataloader)
+    batch_0_restart = next(dataloader_iter)
+
+    sleep(0.1)
+
+    state_dict_2 = dataset.state_dict()
+    assert state_dict_2["0"]["chunk_index"] == 3
+    assert state_dict_2["0"]["index"] == 0
+
+    assert torch.equal(batch_2, batch_0_restart)
+
+
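Note: the closing `torch.equal(batch_2, batch_0_restart)` assertion works because checkpoints are cut at chunk boundaries: `state_dict_1` records `chunk_index == 2` with `index == 0`, so the restored dataset resumes at exactly the batch that followed the snapshot, and the stateless shuffler reproduces the same item order.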
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
+def test_dataset_valid_state(tmpdir):
+    seed_everything(42)
+
+    block_size = 20
+    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+
+    counter = 0
+    for i in range(100):
+        text_ids = torch.arange(counter, counter + 20).to(torch.int)
+        cache[i] = text_ids
+        counter += 20
+
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+    dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
+    dataloader_iter = iter(dataloader)
+    next(dataloader_iter)
+
+    sleep(0.1)
+
+    state_dict = dataset.state_dict()
+
+    dataset.load_state_dict(state_dict)
+    dataset._validate_state_dict()
+
+    state_dict["0"]["drop_last"] = True
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `drop_last` state doesn't match the current one. Found `False` instead of `True`.",  # noqa E501
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["item_loader"] = {}
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `item_loader` state doesn't match the current one. Found `{'block_size': 20}` instead of `{}`.",  # noqa E501
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["seed"] = 12
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `seed` state doesn't match the current one. Found `42` instead of `12`.",  # noqa E501
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["input_dir_url"] = "toto"
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `input_dir` URL state doesn't match the current one. Found `None` instead of `toto`.",  # noqa E501
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["input_dir_path"] = "toto"
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match=f"The provided `input_dir` path state doesn't match the current one. Found `{tmpdir}` instead of `toto`.",  # noqa E501
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["num_workers"] = "8"
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `num_workers` state doesn't match the current one. Found `1` instead of `8`.",  # noqa E501
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["shuffle"] = True
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `shuffle` state doesn't match the current one. Found `False` instead of `True`.",  # noqa E501
+    ):
+        dataset._validate_state_dict()
+
+
+def test_resumable_dataset_distributed_state_dict(tmpdir):
+    seed_everything(42)
+
+    block_size = 20
+    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+
+    counter = 0
+    for i in range(100):
+        text_ids = torch.arange(counter, counter + 20).to(torch.int)
+        cache[i] = text_ids
+        counter += 20
+
+    cache.done()
+    cache.merge()
+
+    assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 50
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+    dataset.distributed_env = _DistributedEnv(world_size=16, global_rank=0)
+
+    # used to create the cache
+    iter(dataset)
+    os.makedirs(dataset.cache.checkpoint_dir, exist_ok=True)
+
+    for i in range(4):
+        now = datetime.now().strftime(_TIME_FORMAT)
+        checkpoint_rank_dir = os.path.join(dataset.cache.checkpoint_dir, str(i))
+        os.makedirs(checkpoint_rank_dir, exist_ok=True)
+        checkpoint_path = os.path.join(checkpoint_rank_dir, f"checkpoint-{now}.json")
+        with open(checkpoint_path, "w") as f:
+            json.dump({}, f)
+
+    torch_mock = mock.MagicMock()
+    torch_mock.cuda.device_count.return_value = 4
+
+    state_list = [{} for _ in range(4)]
+    for i in range(16):
+        state_list[i // 4].update({str(i): {}})
+
+    def broadcast_object_list(obj, src, **kwargs):
+        assert src in [0, 4, 8, 12]
+        obj[0] = state_list.pop(0)
+
+    torch_mock.distributed.broadcast_object_list = broadcast_object_list
+
+    with mock.patch("lightning.data.streaming.dataset.torch", torch_mock):
+        state_dict = dataset.state_dict()
+
+    assert len(state_dict) == 16