Better default for drop_last in a distributed setting (#19478)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
parent 265025bd5d
commit b024e7a73b
@@ -13,6 +13,7 @@
import hashlib
import os
from logging import Logger
from time import time
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -31,6 +32,8 @@ from lightning.data.streaming.serializers import Serializer
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
from lightning.data.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv

logger = Logger(__name__)


class StreamingDataset(IterableDataset):
    """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""

@@ -40,7 +43,7 @@ class StreamingDataset(IterableDataset):
        input_dir: Union[str, "Dir"],
        item_loader: Optional[BaseItemLoader] = None,
        shuffle: bool = False,
        drop_last: bool = False,
        drop_last: Optional[bool] = None,
        seed: int = 42,
        serializers: Optional[Dict[str, Serializer]] = None,
        max_cache_size: Union[int, str] = "100GB",

@@ -53,6 +56,8 @@
            shuffle: Whether to shuffle the data.
            drop_last: If `True`, drops the last items to ensure that
                all processes/workers return the same amount of data.
                The argument `drop_last` is set to `True` in a distributed setting
                and `False` otherwise.
            seed: Random seed for shuffling.
            serializers: The serializers used to serialize and deserialize the chunks.
            max_cache_size: The maximum cache size used by the StreamingDataset.

@@ -68,12 +73,24 @@
        self.item_loader = item_loader
        self.shuffle: bool = shuffle
        self.drop_last = drop_last
        self.distributed_env = _DistributedEnv.detect()

        if self.distributed_env.world_size > 1:
            if drop_last is False:
                logger.warn(
                    "You're operating within a distributed environment and have disabled the `drop_last` option. "
                    "Please note that this configuration may lead to training interruptions if your system depends "
                    "on distributed collectives."
                )
            else:
                drop_last = True

        self.drop_last = drop_last or False

        self.seed = seed
        self.max_cache_size = max_cache_size

        self.cache: Optional[Cache] = None
        self.distributed_env = _DistributedEnv.detect()
        self.worker_env: Optional[_WorkerEnv] = None
        self.worker_chunks: List[int] = []
        self.worker_intervals: List[List[int]] = []
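A minimal standalone sketch of the resolution rule the hunks above introduce: `drop_last=None` means "decide from the environment". The `resolve_drop_last` helper below is hypothetical and only mirrors the logic shown in `__init__`; it is not part of the commit.

import logging
from typing import Optional

logging_logger = logging.getLogger(__name__)


def resolve_drop_last(drop_last: Optional[bool], world_size: int) -> bool:
    # None means "let the dataset decide": keep all items on a single process,
    # drop the trailing remainder when more than one process is reading.
    if world_size > 1:
        if drop_last is False:
            logging_logger.warning("drop_last=False in a distributed setting can stall collective operations.")
        else:
            drop_last = True
    return drop_last or False


assert resolve_drop_last(None, world_size=1) is False   # single process: keep everything
assert resolve_drop_last(None, world_size=2) is True    # distributed: new default
assert resolve_drop_last(False, world_size=2) is False  # explicit opt-out is honoured, with a warning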
@@ -88,7 +88,7 @@ class PILSerializer(Serializer):
        return Image.frombytes(mode, size, raw)  # pyright: ignore

    def can_serialize(self, item: Any) -> bool:
        return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)
        return bool(_PIL_AVAILABLE) and isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)

@@ -133,7 +133,7 @@ class JPEGSerializer(Serializer):
        return img

    def can_serialize(self, item: Any) -> bool:
        return isinstance(item, JpegImageFile)
        return bool(_PIL_AVAILABLE) and isinstance(item, JpegImageFile)


class BytesSerializer(Serializer):
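Both `can_serialize` changes apply the same guard: only claim an item is serializable when the optional Pillow dependency is importable. A rough standalone version of that guard; the local `_PIL_AVAILABLE` flag below is recomputed with `importlib` as a stand-in for the module's own flag.

import importlib.util
from typing import Any

_PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None  # assumption: stand-in for the module's flag


def can_serialize_pil_image(item: Any) -> bool:
    # Checking availability first means the type probe never touches PIL
    # classes on machines where Pillow is not installed.
    if not _PIL_AVAILABLE:
        return False
    from PIL import Image
    from PIL.JpegImagePlugin import JpegImageFile
    return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)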
@@ -13,7 +13,7 @@
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, List, Tuple
from typing import Any, List

import numpy as np

@@ -38,8 +38,10 @@ class Shuffle(ABC):
        items_per_process = [
            sum((interval[-1] - interval[0]) for interval in intervals) for intervals in intervals_per_ranks
        ]
        min_items_per_process = min(items_per_process)
        return min_items_per_process
        # Validate each processes gets the exact number of elements
        if len(items_per_process) > 1:
            assert all(items_per_process[0] == items_to_process for items_to_process in items_per_process[:1])
        return items_per_process[0]

        return sum((interval[-1] - interval[0]) for interval in intervals_per_ranks[distributed_env.global_rank])

@@ -58,16 +60,18 @@ class NoShuffle(Shuffle):
    @lru_cache(maxsize=10)
    def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
        # 1. Get the intervals
        chunk_intervals = self.cache.get_chunk_intervals()
        chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
        intervals_per_ranks: List[List[Tuple]] = [[] for _ in range(distributed_env.world_size)]
        for chunk_index, chunk_interval in enumerate(chunk_intervals):
            replica_index = chunk_index % distributed_env.world_size
            chunks_per_ranks[replica_index].append(chunk_index)
            intervals_per_ranks[replica_index].append(chunk_interval)
        indexes = range(len(chunk_intervals))

        # 2. Compute the items budget of each rank
        chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
            distributed_env, indexes, chunk_intervals, self.drop_last
        )

        return chunks_per_ranks, intervals_per_ranks

    def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
        return array.tolist()
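The `NoShuffle` hunk replaces the inline round-robin loop with a call to `_associate_chunks_and_internals_to_ranks`, which also applies the `drop_last` budget. A simplified sketch of the round-robin part only; this is a reconstruction for illustration, not the helper's actual body.

from typing import List, Tuple

Interval = Tuple[int, int]  # (start, end) item indexes covered by a chunk


def associate_chunks_round_robin(
    chunk_indexes: List[int], chunk_intervals: List[Interval], world_size: int
) -> Tuple[List[List[int]], List[List[Interval]]]:
    # Each rank receives every world_size-th chunk, like the loop removed above.
    chunks_per_ranks: List[List[int]] = [[] for _ in range(world_size)]
    intervals_per_ranks: List[List[Interval]] = [[] for _ in range(world_size)]
    for chunk_index, interval in zip(chunk_indexes, chunk_intervals):
        rank = chunk_index % world_size
        chunks_per_ranks[rank].append(chunk_index)
        intervals_per_ranks[rank].append(interval)
    return chunks_per_ranks, intervals_per_ranks


chunks, intervals = associate_chunks_round_robin([0, 1, 2], [(0, 50), (50, 100), (100, 101)], world_size=2)
assert chunks == [[0, 2], [1]]
assert sum(end - start for start, end in intervals[0]) == 51  # rank 0 gets one extra item unless drop_last trims it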
@@ -45,7 +45,7 @@ class _DistributedEnv:
        # validate the world size is divisble by the number of GPUs
        assert world_size % torch.cuda.device_count() == 0

        return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)
        return cls(world_size=world_size, global_rank=global_rank, num_nodes=max(1, num_nodes))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"
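The `max(1, num_nodes)` change only guards against a node count of zero reaching callers. A tiny illustration of the clamp, under the assumption that `num_nodes` is derived upstream by integer division of the world size by the per-node device count:

def clamp_num_nodes(world_size: int, devices_per_node: int) -> int:
    # world_size // devices_per_node can round down to 0 (e.g. a single process
    # on a multi-GPU machine); clamping keeps downstream arithmetic valid.
    return max(1, world_size // devices_per_node)


assert clamp_num_nodes(world_size=1, devices_per_node=8) == 1
assert clamp_num_nodes(world_size=16, devices_per_node=8) == 2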
@@ -22,6 +22,7 @@ import torch
from lightning import seed_everything
from lightning.data.processing import functions
from lightning.data.streaming import Cache
from lightning.data.streaming import dataset as dataset_module
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import (
    _INDEX_FILENAME,

@@ -100,31 +101,46 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
    assert len(dataset) == 101

    dataset.distributed_env = _DistributedEnv(2, 0, 1)
    assert len(dataset) == 50

    dataset.distributed_env = _DistributedEnv(2, 1, 1)
    assert len(dataset) == 50 + int(not drop_last)

    dataset_iter = iter(dataset)
    assert len(dataset_iter) == 50 + int(not drop_last)

    dataset.distributed_env = _DistributedEnv(2, 0, 1)

    process_1_1 = list(dataset_iter)
    assert len(process_1_1) == 50 + int(not drop_last)

    assert len(process_1_1) == 50
    assert process_1_1[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    dataset_iter = iter(dataset)
    assert len(dataset_iter) == 50 + int(not drop_last)

    assert len(dataset_iter) == 50
    process_1_2 = list(dataset_iter)
    assert process_1_2[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    assert len(process_1_2) == 50 + int(not drop_last)

    assert len(process_1_2) == 50

    dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False, drop_last=drop_last)
    dataset.distributed_env = _DistributedEnv(2, 1, 1)
    assert len(dataset) == 50
    dataset_iter = iter(dataset)
    process_2_1 = list(dataset_iter)
    assert process_2_1[:10] == [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
    assert len(process_2_1) == 50
    dataset_iter = iter(dataset)
    assert len(dataset_iter) == 50
    process_2_2 = list(dataset_iter)
    assert process_2_2[:10] == [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

    assert len(process_2_2) == 50
    assert len(dataset) == 50 + int(not drop_last)
    dataset_iter = iter(dataset)

    process_2_1 = list(dataset_iter)
    assert process_2_1[:10] == [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]

    assert len(process_2_1) == 50 + int(not drop_last)
    dataset_iter = iter(dataset)

    assert len(dataset_iter) == 50 + int(not drop_last)
    process_2_2 = list(dataset_iter)

    assert process_2_2[:10] == [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]

    assert len(process_2_2) == 50 + int(not drop_last)

    _, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_ranks(
        dataset.distributed_env, dataset.current_epoch

@@ -503,11 +519,11 @@ def test_dataset_for_text_tokens_distributed_num_workers(tmpdir):
    assert len(dataset) == 20

    dataset.distributed_env = _DistributedEnv(2, 0, 1)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

    assert len(dataloader) == 6
    assert len(dataloader) == 5

    expected = [[0, 10], [80, 90], [20, 30], [100, 110], [160, 170], [180, 190]]
    expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]]

    for batch_idx, batch in enumerate(dataloader):
        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]

@@ -515,9 +531,9 @@ def test_dataset_for_text_tokens_distributed_num_workers(tmpdir):
    dataset.distributed_env = _DistributedEnv(2, 1, 1)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

    assert len(dataloader) == 4
    assert len(dataloader) == 5

    expected = [[40, 50], [60, 70], [120, 130], [140, 150]]
    expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]]

    for batch_idx, batch in enumerate(dataloader):
        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]

@@ -570,7 +586,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
    assert len(dataloader) == 5

    expected = [[0, 10], [40, 50], [80, 90], [120, 130], [160, 170]]
    expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]]

    for batch_idx, batch in enumerate(dataloader):
        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]

@@ -580,7 +596,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
    assert len(dataloader) == 5

    expected = [[20, 30], [60, 70], [100, 110], [140, 150], [180, 190]]
    expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]]

    for batch_idx, batch in enumerate(dataloader):
        assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]

@@ -760,7 +776,7 @@ def test_dataset_valid_state(tmpdir, monkeypatch):
    cache.merge()

    dataset = EmulateS3StreamingDataset(
        input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=False
        input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=False, drop_last=False,
    )
    dataloader = DataLoader(dataset, num_workers=1, batch_size=2)
    dataloader_iter = iter(dataloader)

@@ -865,3 +881,26 @@ def test_replay_chunks_sampling():
    assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1})
    assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3})
    assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2})


def test_dataset_distributed_drop_last(tmpdir, monkeypatch):

    class _DistributedEnvMock():

        def detect(cls):
            return _DistributedEnv(2, 0, 1)

    logger_mock = mock.MagicMock()

    monkeypatch.setattr(dataset_module, "_DistributedEnv", _DistributedEnvMock())
    monkeypatch.setattr(dataset_module, "logger", logger_mock)

    dataset = StreamingDataset(str(tmpdir), drop_last=None)
    assert dataset.drop_last

    dataset = StreamingDataset(str(tmpdir), drop_last=False)
    assert not dataset.drop_last

    warn_value = logger_mock.warn._mock_mock_calls[0].args[0]
    assert warn_value == "You're operating within a distributed environment and have disabled the `drop_last`" \
        " option. Please note that this configuration may lead to training interruptions if your system depends on distributed collectives."  # noqa: E501
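For the user-facing effect, a hedged usage sketch (the input path and the multi-process launch are assumptions, not taken from the diff): under a launcher such as torchrun with more than one process, `drop_last` now resolves to `True` without being passed, so every rank yields the same number of samples; passing `drop_last=False` restores the old behaviour and logs the warning shown above.

from torch.utils.data import DataLoader

from lightning.data.streaming.dataset import StreamingDataset

# Hypothetical directory produced by the dataset optimisation step.
dataset = StreamingDataset(input_dir="/data/optimised-shards", shuffle=True)

# No drop_last argument: it resolves to True when world_size > 1, False otherwise.
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
for batch in dataloader:
    ...  # training step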