Add fault tolerance for the StreamingDataset 1/n (#19049)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
thomas chaton 2023-11-22 17:22:00 +00:00 committed by GitHub
parent bc1658039f
commit 1073276a58
8 changed files with 476 additions and 30 deletions

@@ -16,4 +16,11 @@ from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcess
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader
__all__ = ["Cache", "DataProcessor", "StreamingDataset", "DataTransformRecipe", "DataChunkRecipe", "TokensLoader"]
__all__ = [
"Cache",
"DataProcessor",
"StreamingDataset",
"DataTransformRecipe",
"DataChunkRecipe",
"TokensLoader",
]
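For reference, a minimal sketch of what the expanded `__all__` exposes at the package root (assuming `lightning` with the streaming data extras installed):

from lightning.data.streaming import Cache, DataProcessor, StreamingDataset, TokensLoader

# All six names listed in __all__ above remain importable from the package root.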

@@ -94,6 +94,10 @@ class Cache:
self._is_done = False
self._distributed_env = _DistributedEnv.detect()
@property
def rank(self) -> int:
return self._reader.rank
@property
def filled(self) -> bool:
"""Returns whether the caching phase is done."""
@@ -102,6 +106,20 @@ class Cache:
self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
return self._is_done
@property
def checkpoint_dir(self) -> str:
checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)
return checkpoint_dir
@property
def checkpoint_rank_dir(self) -> str:
checkpoint_rank_dir = os.path.join(self.checkpoint_dir, str(self.rank))
if not os.path.exists(checkpoint_rank_dir):
os.makedirs(checkpoint_rank_dir, exist_ok=True)
return checkpoint_rank_dir
def __setitem__(self, index: int, data: Any) -> None:
"""Store an item in the writer."""
self._writer[index] = data
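For illustration, the two properties above lazily build a per-rank checkpoint tree under the cache directory, so concurrent workers never write into the same folder. A hypothetical stand-alone sketch of the same logic, assuming a cache rooted at `/cache`:

import os

def checkpoint_rank_dir(cache_dir: str, rank: int) -> str:
    # Mirrors Cache.checkpoint_rank_dir: <cache_dir>/checkpoints/<rank>
    path = os.path.join(cache_dir, "checkpoints", str(rank))
    os.makedirs(path, exist_ok=True)  # idempotent, safe to call repeatedly
    return path

# checkpoint_rank_dir("/cache", 0) -> "/cache/checkpoints/0"
# checkpoint_rank_dir("/cache", 1) -> "/cache/checkpoints/1"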

@@ -51,3 +51,5 @@ _TORCH_DTYPES_MAPPING = {
18: torch.long,
19: torch.bool,
}
_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
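A quick round-trip showing how `_TIME_FORMAT` is used to name checkpoint files so they sort chronologically; note the trailing `Z` is a literal suffix, not a timezone conversion:

from datetime import datetime

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"

stamp = datetime.now().strftime(_TIME_FORMAT)    # e.g. "2023-11-22_17-22-00.123456Z"
parsed = datetime.strptime(stamp, _TIME_FORMAT)  # lossless down to microseconds
assert parsed.strftime(_TIME_FORMAT) == stamp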

@@ -12,20 +12,34 @@
# limitations under the License.
import hashlib
import json
import os
import shutil
import sys
import tempfile
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from time import time
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from torch.utils.data import IterableDataset
from lightning.data.streaming import Cache
from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
from lightning.data.streaming.constants import (
_DEFAULT_CACHE_DIR,
_INDEX_FILENAME,
_LIGHTNING_CLOUD_LATEST,
_TIME_FORMAT,
)
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import Serializer
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
from lightning.fabric.utilities.distributed import group as _group
if _LIGHTNING_CLOUD_LATEST:
from lightning_cloud.resolver import Dir, _resolve_dir
@@ -42,6 +56,7 @@ class StreamingDataset(IterableDataset):
drop_last: bool = False,
seed: int = 42,
serializers: Optional[Dict[str, Serializer]] = None,
checkpoint_interval: int = 60 * 5,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -53,6 +68,7 @@ class StreamingDataset(IterableDataset):
all processes/workers return the same amount of data.
seed: Random seed for shuffling.
serializers: The serializers used to serialize and deserialize the chunks.
checkpoint_interval: Interval in seconds at which each worker stores its own progress.
"""
super().__init__()
@@ -77,6 +93,7 @@ class StreamingDataset(IterableDataset):
self.worker_intervals: List[List[int]] = []
self.current_indexes: List[int] = []
self.chunk_index = 0
self.global_index = 0
self.index = 0
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
@@ -84,6 +101,8 @@ class StreamingDataset(IterableDataset):
self.random_state = None
self.shuffler: Optional[Shuffle] = None
self.serializers = serializers
self.checkpoint_interval = checkpoint_interval
self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
@@ -109,11 +128,10 @@ class StreamingDataset(IterableDataset):
return cache
def _create_shuffler(self, cache: Cache) -> Shuffle:
return (
FullShuffle(cache, self.seed, self.drop_last)
if self.shuffle
else NoShuffle(cache, self.seed, self.drop_last)
)
seed = self.seed
if self._state_dict is not None:
seed = self._state_dict[str(cache.rank)]["seed"]
return FullShuffle(cache, seed, self.drop_last) if self.shuffle else NoShuffle(cache, seed, self.drop_last)
def __len__(self) -> int:
if self.shuffler is None:
@@ -126,6 +144,17 @@ class StreamingDataset(IterableDataset):
self.cache = self._create_cache(worker_env=self.worker_env)
self.shuffler = self._create_shuffler(self.cache)
# Handle restart
if self._state_dict:
self._validate_state_dict()
state = self._state_dict[str(self.cache.rank)]
# reload indexes
self.chunk_index = state["chunk_index"]
self.global_index = state["global_index"]
self.index = state["index"]
self.current_epoch = state["current_epoch"]
chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
self.distributed_env, self.current_epoch
)
@@ -141,10 +170,26 @@ class StreamingDataset(IterableDataset):
self.worker_chunks.append(chunk_index)
self.worker_intervals.append(chunk_interval)
self.current_indexes = []
self.chunk_index = 0
self.index = 0
# Handle restart
if self._state_dict:
state = self._state_dict[str(self.cache.rank)]
# re-generate indexes
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
self.current_indexes = current_indexes[state["index"] :]
# Bump the chunk_index
self.chunk_index += 1
else:
self.current_indexes = []
self.chunk_index = 0
self.global_index = 0
self.index = 0
self.has_triggered_download = False
self.last_time = time()
return self
@@ -159,7 +204,7 @@ class StreamingDataset(IterableDataset):
def __next__(self) -> Any:
# Stop once this process has yielded all of its assigned items
if self.index >= len(self):
if self.global_index >= len(self):
self.current_epoch += 1
raise StopIteration
@@ -169,14 +214,19 @@ class StreamingDataset(IterableDataset):
self.current_epoch += 1
raise StopIteration
# reset index
self.index = 0
# Checkpoint when reaching a new chunk
self.checkpoint(self.chunk_index)
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
assert self.shuffler is not None
self.current_indexes = self.shuffler(current_indexes)
self.chunk_index += 1
self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1
self.chunk_index += 1
# Get the first index
index = self.current_indexes.pop(0)
@@ -188,15 +238,165 @@ class StreamingDataset(IterableDataset):
chunk_index=self.worker_chunks[self.chunk_index - 1],
# We provide the chunk indexes only on the first index
chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
last_index=last_index,
last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
)
)
self.has_triggered_download = True
self.global_index += 1
self.index += 1
# Checkpoint based on time
if (time() - self.last_time) > self.checkpoint_interval:
self.checkpoint(self.chunk_index - 1)
return data
def checkpoint(self, chunk_index: int) -> None:
# Checkpointing isn't supported on Windows
if sys.platform == "win32":
return
assert self.cache
assert self.worker_env
with tempfile.TemporaryDirectory() as tmpdir:
tmp_checkpoint_path = os.path.join(tmpdir, "checkpoint.json")
with open(tmp_checkpoint_path, "w") as f:
# 1. Write the state to a tempfile
json.dump(
{
"rank": self.cache._reader.rank,
"current_epoch": self.current_epoch,
"input_dir_path": self.input_dir.path,
"input_dir_url": self.input_dir.url,
"item_loader": self.item_loader.state_dict() if self.item_loader else None,
"drop_last": self.drop_last,
"seed": self.seed,
"checkpoint_interval": self.checkpoint_interval,
"chunk_index": chunk_index,
"global_index": self.global_index,
"index": self.index,
"world_size": self.distributed_env.world_size,
"num_workers": self.worker_env.world_size,
"shuffle": self.shuffle,
},
f,
)
# 2. Build a timestamped filename inside this rank's checkpoint directory
now = datetime.now().strftime(_TIME_FORMAT)
checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
# 3. Atomically move the file into place so the main process never reads a partially written checkpoint
shutil.move(tmp_checkpoint_path, checkpoint_path)
self.last_time = time()
def state_dict(self) -> Dict[str, Any]:
if self.cache is None:
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)
state_dict: Dict[str, Any] = {}
worker_env = _WorkerEnv.detect()
if worker_env.world_size == 1:
# 1. Check whether the checkpoint_dir exists
if not os.path.exists(self.cache.checkpoint_dir):
return state_dict
# 2. Iterate through the workers and read the latest checkpoint
for worker_idx in os.listdir(self.cache.checkpoint_dir):
checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
checkpoints = sorted(checkpoints, key=_string_to_datetime)
# Load the latest checkpoint for this worker
checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
with open(checkpoint_path) as f:
state_dict[worker_idx] = json.load(f)
_state_dict = deepcopy(state_dict)
if self.distributed_env.world_size > 1:
# TODO: Move this to fabric.
num_devices = torch.cuda.device_count() or 1
node_ranks = []
for index in range(self.distributed_env.world_size):
node_rank = index // num_devices
if node_rank in node_ranks:
continue
state = {}
obj = [_state_dict]
torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
state = obj[0]
state_dict.update(**state)
node_ranks.append(node_rank)
else:
raise NotImplementedError("The `state_dict` should be called from the main process.")
return state_dict
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if state_dict:
# the state is restored within the workers
self._state_dict = state_dict
def _validate_state_dict(self) -> None:
assert self._state_dict
assert self.worker_env
assert self.cache
env = Environment(dist_env=self.distributed_env, worker_env=self.worker_env)
if env.num_shards != len(self._state_dict):
raise ValueError(
"The provided `state` size doesn't match the number workers world size. "
f"Found `{env.num_shards}` instead of `{len(self._state_dict)}`."
)
state: Dict[str, Any] = self._state_dict[str(self.cache.rank)]
if state["shuffle"] != self.shuffle:
raise ValueError(
"The provided `shuffle` state doesn't match the current one. "
f"Found `{self.shuffle}` instead of `{state['shuffle']}`."
)
if state["num_workers"] != self.worker_env.world_size:
raise ValueError(
"The provided `num_workers` state doesn't match the current one. "
f"Found `{self.worker_env.world_size}` instead of `{state['num_workers']}`."
)
if state["input_dir_path"] != self.input_dir.path:
raise ValueError(
"The provided `input_dir` path state doesn't match the current one. "
f"Found `{self.input_dir.path}` instead of `{state['input_dir_path']}`."
)
if state["input_dir_url"] != self.input_dir.url:
raise ValueError(
"The provided `input_dir` URL state doesn't match the current one. "
f"Found `{self.input_dir.url}` instead of `{state['input_dir_url']}`."
)
if state["seed"] != self.seed:
raise ValueError(
"The provided `seed` state doesn't match the current one. "
f"Found `{self.seed}` instead of `{state['seed']}`."
)
if self.item_loader and state["item_loader"] != self.item_loader.state_dict():
raise ValueError(
"The provided `item_loader` state doesn't match the current one. "
f"Found `{self.item_loader.state_dict()}` instead of `{state['item_loader']}`."
)
if state["drop_last"] != self.drop_last:
raise ValueError(
"The provided `drop_last` state doesn't match the current one. "
f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
)
def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
hash_object = hashlib.md5(input_dir.encode())
@@ -209,6 +409,10 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
return cache_dir
def _string_to_datetime(item: str) -> datetime:
return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)
@dataclass
class RemoteDir:
"""Holds a remote URL to a directory and a cache directory where the data will be downloaded."""

@@ -37,6 +37,9 @@ class BaseItemLoader(ABC):
self._chunks = chunks
self._serializers = serializers
def state_dict(self) -> Dict:
return {}
@abstractmethod
def generate_intervals(self) -> List[Tuple[int, int]]:
"""Returns a list of tuple describing the indexes intervals of the chunks."""
@@ -115,6 +118,11 @@ class TokensLoader(BaseItemLoader):
self._dtype: Optional[torch.dtype] = None
self._chunk_filepaths: Dict[str, bool] = {}
def state_dict(self) -> Dict:
return {
"block_size": self._block_size,
}
def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer]) -> None:
super().setup(config, chunks, serializers)
self._dtype = _TORCH_DTYPES_MAPPING[int(config["data_format"][0].split(":")[1])]
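A small sketch of what the new `TokensLoader.state_dict` captures; the block size recorded at checkpoint time is later compared against the restored loader by `_validate_state_dict`:

from lightning.data.streaming.item_loader import TokensLoader

loader = TokensLoader(block_size=20)
assert loader.state_dict() == {"block_size": 20}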

@@ -28,7 +28,6 @@ class Shuffle(ABC):
self.cache = cache
self.seed = seed
self.drop_last = drop_last
self.random_state = None
@lru_cache(maxsize=10)
def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
@@ -48,7 +47,7 @@
pass
@abstractmethod
def __call__(self, array: np.ndarray) -> List[int]:
def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
pass
@@ -68,7 +67,7 @@ class NoShuffle(Shuffle):
return chunks_per_ranks, intervals_per_ranks
def __call__(self, array: np.ndarray) -> List[int]:
def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
return array.tolist()
@@ -92,14 +91,12 @@ class FullShuffle(Shuffle):
@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore
# 1. Get the intervals
chunk_intervals = self.cache.get_chunk_intervals()
# 2. Shuffle them
indexes = range(len(chunk_intervals))
shuffled_indexes = self.random_state.permutation(indexes)
shuffled_indexes = np.random.RandomState(seed=self.seed + current_epoch).permutation(indexes)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
# 3. Compute the items budget of each rank
@@ -147,6 +144,5 @@ class FullShuffle(Shuffle):
return chunks_per_ranks, intervals_per_ranks
def __call__(self, array: np.ndarray) -> List[int]:
assert self.random_state
return self.random_state.permutation(array).tolist()
def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
return np.random.RandomState(seed=self.seed + current_epoch + chunk_index).permutation(array).tolist()
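The shuffle is now a pure function of `(seed, current_epoch, chunk_index)` instead of carrying a `RandomState` between calls, which is what lets a restarted worker regenerate exactly the permutation it was consuming. A minimal sketch of that property:

import numpy as np

seed, current_epoch, chunk_index = 42, 1, 3
array = np.arange(10)

# Same inputs always yield the same permutation, with no state carried over.
first = np.random.RandomState(seed=seed + current_epoch + chunk_index).permutation(array).tolist()
second = np.random.RandomState(seed=seed + current_epoch + chunk_index).permutation(array).tolist()
assert first == second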

@@ -161,7 +161,7 @@ def test_remove_target(tmpdir):
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on Windows")
@mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
def test_download_data_target(tmpdir):
def test_download_data_target(wait_for_disk_usage_higher_than_threshold_mock, tmpdir):
input_dir = os.path.join(tmpdir, "input_dir")
os.makedirs(input_dir, exist_ok=True)
@@ -194,6 +194,8 @@ def test_download_data_target(tmpdir):
assert os.listdir(cache_dir) == ["a.txt"]
wait_for_disk_usage_higher_than_threshold_mock.assert_called()
def test_wait_for_disk_usage_higher_than_threshold():
disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])

@@ -11,8 +11,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
from datetime import datetime
from time import sleep
from unittest import mock
import numpy as np
@@ -20,6 +23,7 @@ import pytest
import torch
from lightning import seed_everything
from lightning.data.streaming import Cache, functions
from lightning.data.streaming.constants import _TIME_FORMAT
from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
@@ -160,7 +164,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
dataset_iter = iter(dataset)
assert len(dataset_iter) == 548
process_1_1 = list(dataset_iter)
assert process_1_1[:10] == [785, 788, 782, 783, 789, 787, 786, 781, 784, 780]
assert process_1_1[:10] == [788, 781, 785, 780, 787, 782, 789, 784, 783, 786]
assert len(process_1_1) == 548
dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
@@ -171,7 +175,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
dataset_2_iter = iter(dataset_2)
assert len(dataset_2_iter) == 548 + int(not drop_last)
process_2_1 = list(dataset_2_iter)
assert process_2_1[:10] == [939, 938, 252, 259, 257, 255, 258, 253, 250, 251]
assert process_2_1[:10] == [939, 938, 253, 259, 256, 258, 252, 255, 251, 257]
assert len(process_2_1) == 548 + int(not drop_last)
assert len([i for i in process_1_1 if i in process_2_1]) == 0
@@ -200,7 +204,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
dataset_iter = iter(dataset)
assert len(dataset_iter) == 611
process_1_1 = list(dataset_iter)
assert process_1_1[:10] == [185, 184, 182, 189, 187, 181, 183, 180, 186, 188]
assert process_1_1[:10] == [188, 181, 185, 180, 187, 182, 189, 184, 183, 186]
assert len(process_1_1) == 611
dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
@@ -211,9 +215,8 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
dataset_2_iter = iter(dataset_2)
assert len(dataset_2_iter) == 611
process_2_1 = list(dataset_2_iter)
assert process_2_1[:10] == [813, 815, 816, 812, 818, 811, 817, 814, 819, 277]
assert process_2_1[:10] == [818, 812, 816, 811, 819, 813, 815, 814, 817, 273]
assert len(process_2_1) == 611
assert len([i for i in process_1_1 if i in process_2_1]) == 0
@@ -527,3 +530,209 @@ def test_s3_streaming_dataset():
dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet")
assert dataset.input_dir.url == "s3://pl-flash-data/optimized_tiny_imagenet"
assert dataset.input_dir.path is None
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows")
def test_resumable_dataset_single_worker(tmpdir):
seed_everything(42)
block_size = 20
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
counter = 0
for i in range(100):
text_ids = torch.arange(counter, counter + 20).to(torch.int)
cache[i] = text_ids
counter += 20
cache.done()
cache.merge()
assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 50
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
dataset.current_epoch = 1
assert dataset.state_dict() == {}
dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
dataloader_iter = iter(dataloader)
_ = next(dataloader_iter)
state_dict_0 = dataset.state_dict()
sleep(0.1)
assert state_dict_0["0"]["chunk_index"] == 0
assert state_dict_0["0"]["index"] == 0
checkpoint_dir = os.path.join(tmpdir, "checkpoints")
assert os.listdir(checkpoint_dir) == ["0"]
_ = next(dataloader_iter)
sleep(0.1)
state_dict_1 = dataset.state_dict()
assert state_dict_1["0"]["chunk_index"] == 2
assert state_dict_1["0"]["index"] == 0
batch_2 = next(dataloader_iter)
sleep(0.1)
state_dict_2 = dataset.state_dict()
assert state_dict_2["0"]["chunk_index"] == 3
assert state_dict_2["0"]["index"] == 0
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
dataset.load_state_dict(state_dict_1)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
dataloader_iter = iter(dataloader)
batch_0_restart = next(dataloader_iter)
sleep(0.1)
state_dict_2 = dataset.state_dict()
assert state_dict_2["0"]["chunk_index"] == 3
assert state_dict_2["0"]["index"] == 0
assert torch.equal(batch_2, batch_0_restart)
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows")
def test_dataset_valid_state(tmpdir):
seed_everything(42)
block_size = 20
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
counter = 0
for i in range(100):
text_ids = torch.arange(counter, counter + 20).to(torch.int)
cache[i] = text_ids
counter += 20
cache.done()
cache.merge()
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
dataloader_iter = iter(dataloader)
next(dataloader_iter)
sleep(0.1)
state_dict = dataset.state_dict()
dataset.load_state_dict(state_dict)
dataset._validate_state_dict()
state_dict["0"]["drop_last"] = True
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match="The provided `drop_last` state doesn't match the current one. Found `False` instead of `True`.", # noqa E501
):
dataset._validate_state_dict()
state_dict["0"]["item_loader"] = {}
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match="The provided `item_loader` state doesn't match the current one. Found `{'block_size': 20}` instead of `{}`.", # noqa E501
):
dataset._validate_state_dict()
state_dict["0"]["seed"] = 12
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match="The provided `seed` state doesn't match the current one. Found `42` instead of `12`.", # noqa E501
):
dataset._validate_state_dict()
state_dict["0"]["input_dir_url"] = "toto"
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match="The provided `input_dir` URL state doesn't match the current one. Found `None` instead of `toto`.", # noqa E501
):
dataset._validate_state_dict()
state_dict["0"]["input_dir_path"] = "toto"
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match=f"The provided `input_dir` path state doesn't match the current one. Found `{tmpdir}` instead of `toto`.", # noqa E501
):
dataset._validate_state_dict()
state_dict["0"]["num_workers"] = "8"
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match=f"The provided `num_workers` state doesn't match the current one. Found `1` instead of `8`.", # noqa E501
):
dataset._validate_state_dict()
state_dict["0"]["shuffle"] = True
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match=f"The provided `shuffle` state doesn't match the current one. Found `False` instead of `True`.", # noqa E501
):
dataset._validate_state_dict()
def test_resumable_dataset_distributed_state_dict(tmpdir):
seed_everything(42)
block_size = 20
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
counter = 0
for i in range(100):
text_ids = torch.arange(counter, counter + 20).to(torch.int)
cache[i] = text_ids
counter += 20
cache.done()
cache.merge()
assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 50
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
dataset.distributed_env = _DistributedEnv(world_size=16, global_rank=0)
# used to create the cache
iter(dataset)
os.makedirs(dataset.cache.checkpoint_dir, exist_ok=True)
for i in range(4):
now = datetime.now().strftime(_TIME_FORMAT)
checkpoint_rank_dir = os.path.join(dataset.cache.checkpoint_dir, str(i))
os.makedirs(checkpoint_rank_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_rank_dir, f"checkpoint-{now}.json")
with open(checkpoint_path, "w") as f:
json.dump({}, f)
torch_mock = mock.MagicMock()
torch_mock.cuda.device_count.return_value = 4
state_list = [{} for _ in range(4)]
for i in range(16):
state_list[i // 4].update({str(i): {}})
def broadcast_object_list(obj, src, **kwargs):
assert src in [0, 4, 8, 12]
obj[0] = state_list.pop(0)
torch_mock.distributed.broadcast_object_list = broadcast_object_list
with mock.patch("lightning.data.streaming.dataset.torch", torch_mock):
state_dict = dataset.state_dict()
assert len(state_dict) == 16