Better default for drop_last in a distributed setting (#19478)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
thomas chaton 2024-02-15 17:11:45 +00:00 committed by GitHub
parent 265025bd5d
commit b024e7a73b
5 changed files with 96 additions and 36 deletions

View File

@ -13,6 +13,7 @@
import hashlib
import os
from logging import Logger
from time import time
from typing import Any, Dict, List, Optional, Tuple, Union
@ -31,6 +32,8 @@ from lightning.data.streaming.serializers import Serializer
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
from lightning.data.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
logger = Logger(__name__)
class StreamingDataset(IterableDataset):
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""
@ -40,7 +43,7 @@ class StreamingDataset(IterableDataset):
input_dir: Union[str, "Dir"],
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = False,
drop_last: bool = False,
drop_last: Optional[bool] = None,
seed: int = 42,
serializers: Optional[Dict[str, Serializer]] = None,
max_cache_size: Union[int, str] = "100GB",
@ -53,6 +56,8 @@ class StreamingDataset(IterableDataset):
shuffle: Whether to shuffle the data.
drop_last: If `True`, drops the last items to ensure that
all processes/workers return the same amount of data.
When left as `None` (the default), `drop_last` resolves to `True` in a distributed setting
and `False` otherwise.
seed: Random seed for shuffling.
serializers: The serializers used to serialize and deserialize the chunks.
max_cache_size: The maximum cache size used by the StreamingDataset.
@ -68,12 +73,24 @@ class StreamingDataset(IterableDataset):
self.item_loader = item_loader
self.shuffle: bool = shuffle
self.drop_last = drop_last
self.distributed_env = _DistributedEnv.detect()
if self.distributed_env.world_size > 1:
if drop_last is False:
logger.warn(
"You're operating within a distributed environment and have disabled the `drop_last` option. "
"Please note that this configuration may lead to training interruptions if your system depends "
"on distributed collectives."
)
else:
drop_last = True
self.drop_last = drop_last or False
self.seed = seed
self.max_cache_size = max_cache_size
self.cache: Optional[Cache] = None
self.distributed_env = _DistributedEnv.detect()
self.worker_env: Optional[_WorkerEnv] = None
self.worker_chunks: List[int] = []
self.worker_intervals: List[List[int]] = []
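
A minimal standalone sketch of the resolution rule introduced above, under the assumption that only the world size matters; `resolve_drop_last` and `user_drop_last` are hypothetical names, and the dataset itself reads the world size from `_DistributedEnv.detect()` and warns through its module `logger` rather than the `warnings` module:

```python
import warnings
from typing import Optional


def resolve_drop_last(user_drop_last: Optional[bool], world_size: int) -> bool:
    """Mirror of the new default: True in a distributed run unless explicitly disabled."""
    if world_size > 1:
        if user_drop_last is False:
            # An explicit opt-out is respected but warned about, since uneven ranks
            # can interrupt a job that relies on distributed collectives.
            warnings.warn(
                "drop_last=False in a distributed environment may lead to training "
                "interruptions if your system depends on distributed collectives."
            )
            return False
        return True
    return bool(user_drop_last)


assert resolve_drop_last(None, world_size=2) is True
assert resolve_drop_last(False, world_size=2) is False
assert resolve_drop_last(None, world_size=1) is False
```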

View File

@ -88,7 +88,7 @@ class PILSerializer(Serializer):
return Image.frombytes(mode, size, raw) # pyright: ignore
def can_serialize(self, item: Any) -> bool:
return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)
return bool(_PIL_AVAILABLE) and isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)
class JPEGSerializer(Serializer):
@ -133,7 +133,7 @@ class JPEGSerializer(Serializer):
return img
def can_serialize(self, item: Any) -> bool:
return isinstance(item, JpegImageFile)
return bool(_PIL_AVAILABLE) and isinstance(item, JpegImageFile)
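
Both changes guard `can_serialize` behind the optional Pillow dependency. A self-contained sketch of the pattern, assuming a locally recreated `_PIL_AVAILABLE` flag rather than the one the library defines:

```python
import importlib.util
from typing import Any

# Recreated locally for illustration; the library defines its own availability flag.
_PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None

if _PIL_AVAILABLE:
    from PIL import Image
    from PIL.JpegImagePlugin import JpegImageFile


def can_serialize_pil(item: Any) -> bool:
    # The short-circuiting `and` means Image/JpegImageFile are never looked up
    # when Pillow is absent, so this returns False instead of raising NameError.
    return bool(_PIL_AVAILABLE) and isinstance(item, Image.Image) and not isinstance(item, JpegImageFile)
```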
class BytesSerializer(Serializer):

View File

@ -13,7 +13,7 @@
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, List, Tuple
from typing import Any, List
import numpy as np
@ -38,8 +38,10 @@ class Shuffle(ABC):
items_per_process = [
sum((interval[-1] - interval[0]) for interval in intervals) for intervals in intervals_per_ranks
]
min_items_per_process = min(items_per_process)
return min_items_per_process
# Validate that each process gets exactly the same number of elements
if len(items_per_process) > 1:
assert all(items_per_process[0] == items_to_process for items_to_process in items_per_process[1:])
return items_per_process[0]
return sum((interval[-1] - interval[0]) for interval in intervals_per_ranks[distributed_env.global_rank])
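
To make the per-rank length rule concrete, a hedged sketch using a hypothetical `rank_length` helper and plain `(begin, end)` interval lists instead of the cache's chunk metadata; with `drop_last`, the chunk assignment is assumed to have already trimmed every rank to the same budget, which the check then asserts:

```python
from typing import Sequence, Tuple

Interval = Tuple[int, int]  # half-open (begin, end)


def rank_length(intervals_per_rank: Sequence[Sequence[Interval]], rank: int, drop_last: bool) -> int:
    counts = [sum(end - begin for begin, end in intervals) for intervals in intervals_per_rank]
    if drop_last:
        # The chunk assignment is expected to have equalised the budgets already.
        assert all(count == counts[0] for count in counts[1:])
        return counts[0]
    return counts[rank]


# 101 items over 2 ranks: equalised when drop_last is True, uneven otherwise.
assert rank_length([[(0, 50)], [(51, 101)]], rank=0, drop_last=True) == 50
assert rank_length([[(0, 51)], [(51, 101)]], rank=0, drop_last=False) == 51
assert rank_length([[(0, 51)], [(51, 101)]], rank=1, drop_last=False) == 50
```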
@ -58,16 +60,18 @@ class NoShuffle(Shuffle):
@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
# 1. Get the intervals
chunk_intervals = self.cache.get_chunk_intervals()
chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
intervals_per_ranks: List[List[Tuple]] = [[] for _ in range(distributed_env.world_size)]
for chunk_index, chunk_interval in enumerate(chunk_intervals):
replica_index = chunk_index % distributed_env.world_size
chunks_per_ranks[replica_index].append(chunk_index)
intervals_per_ranks[replica_index].append(chunk_interval)
indexes = range(len(chunk_intervals))
# 2. Compute the items budget of each rank
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
distributed_env, indexes, chunk_intervals, self.drop_last
)
return chunks_per_ranks, intervals_per_ranks
def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
return array.tolist()
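
For reference, a self-contained sketch of the round-robin assignment that the removed inline loop performed; the shared `_associate_chunks_and_internals_to_ranks` helper that replaces it also receives `drop_last`, and its internals are not shown in this diff, so nothing further is assumed about them here:

```python
from typing import List, Sequence, Tuple

Interval = Tuple[int, int]


def round_robin_chunks(
    chunk_intervals: Sequence[Interval], world_size: int
) -> Tuple[List[List[int]], List[List[Interval]]]:
    """Assign chunk indexes and their intervals to ranks in round-robin order."""
    chunks_per_rank: List[List[int]] = [[] for _ in range(world_size)]
    intervals_per_rank: List[List[Interval]] = [[] for _ in range(world_size)]
    for chunk_index, interval in enumerate(chunk_intervals):
        rank = chunk_index % world_size
        chunks_per_rank[rank].append(chunk_index)
        intervals_per_rank[rank].append(interval)
    return chunks_per_rank, intervals_per_rank


chunks, intervals = round_robin_chunks([(0, 25), (25, 50), (50, 75), (75, 101)], world_size=2)
assert chunks == [[0, 2], [1, 3]]
assert intervals == [[(0, 25), (50, 75)], [(25, 50), (75, 101)]]
```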

View File

@ -45,7 +45,7 @@ class _DistributedEnv:
# validate that the world size is divisible by the number of GPUs
assert world_size % torch.cuda.device_count() == 0
return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)
return cls(world_size=world_size, global_rank=global_rank, num_nodes=max(1, num_nodes))
def __repr__(self) -> str:
return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"

View File

@ -22,6 +22,7 @@ import torch
from lightning import seed_everything
from lightning.data.processing import functions
from lightning.data.streaming import Cache
from lightning.data.streaming import dataset as dataset_module
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import (
_INDEX_FILENAME,
@ -100,31 +101,46 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
assert len(dataset) == 101
dataset.distributed_env = _DistributedEnv(2, 0, 1)
assert len(dataset) == 50
dataset.distributed_env = _DistributedEnv(2, 1, 1)
assert len(dataset) == 50 + int(not drop_last)
dataset_iter = iter(dataset)
assert len(dataset_iter) == 50 + int(not drop_last)
dataset.distributed_env = _DistributedEnv(2, 0, 1)
process_1_1 = list(dataset_iter)
assert len(process_1_1) == 50 + int(not drop_last)
assert len(process_1_1) == 50
assert process_1_1[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
dataset_iter = iter(dataset)
assert len(dataset_iter) == 50 + int(not drop_last)
assert len(dataset_iter) == 50
process_1_2 = list(dataset_iter)
assert process_1_2[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert len(process_1_2) == 50 + int(not drop_last)
assert len(process_1_2) == 50
dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False, drop_last=drop_last)
dataset.distributed_env = _DistributedEnv(2, 1, 1)
assert len(dataset) == 50
dataset_iter = iter(dataset)
process_2_1 = list(dataset_iter)
assert process_2_1[:10] == [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
assert len(process_2_1) == 50
dataset_iter = iter(dataset)
assert len(dataset_iter) == 50
process_2_2 = list(dataset_iter)
assert process_2_2[:10] == [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
assert len(process_2_2) == 50
assert len(dataset) == 50 + int(not drop_last)
dataset_iter = iter(dataset)
process_2_1 = list(dataset_iter)
assert process_2_1[:10] == [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]
assert len(process_2_1) == 50 + int(not drop_last)
dataset_iter = iter(dataset)
assert len(dataset_iter) == 50 + int(not drop_last)
process_2_2 = list(dataset_iter)
assert process_2_2[:10] == [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]
assert len(process_2_2) == 50 + int(not drop_last)
_, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_ranks(
dataset.distributed_env, dataset.current_epoch
@ -503,11 +519,11 @@ def test_dataset_for_text_tokens_distributed_num_workers(tmpdir):
assert len(dataset) == 20
dataset.distributed_env = _DistributedEnv(2, 0, 1)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
assert len(dataloader) == 6
assert len(dataloader) == 5
expected = [[0, 10], [80, 90], [20, 30], [100, 110], [160, 170], [180, 190]]
expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]]
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
@ -515,9 +531,9 @@ def test_dataset_for_text_tokens_distributed_num_workers(tmpdir):
dataset.distributed_env = _DistributedEnv(2, 1, 1)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
assert len(dataloader) == 4
assert len(dataloader) == 5
expected = [[40, 50], [60, 70], [120, 130], [140, 150]]
expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]]
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
@ -570,7 +586,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
assert len(dataloader) == 5
expected = [[0, 10], [40, 50], [80, 90], [120, 130], [160, 170]]
expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]]
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
@ -580,7 +596,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
assert len(dataloader) == 5
expected = [[20, 30], [60, 70], [100, 110], [140, 150], [180, 190]]
expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]]
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
@ -760,7 +776,7 @@ def test_dataset_valid_state(tmpdir, monkeypatch):
cache.merge()
dataset = EmulateS3StreamingDataset(
input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=False
input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=False, drop_last=False,
)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2)
dataloader_iter = iter(dataloader)
@ -865,3 +881,26 @@ def test_replay_chunks_sampling():
assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1})
assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3})
assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2})
def test_dataset_distributed_drop_last(tmpdir, monkeypatch):
class _DistributedEnvMock():
def detect(cls):
return _DistributedEnv(2, 0, 1)
logger_mock = mock.MagicMock()
monkeypatch.setattr(dataset_module, "_DistributedEnv", _DistributedEnvMock())
monkeypatch.setattr(dataset_module, "logger", logger_mock)
dataset = StreamingDataset(str(tmpdir), drop_last=None)
assert dataset.drop_last
dataset = StreamingDataset(str(tmpdir), drop_last=False)
assert not dataset.drop_last
warn_value = logger_mock.warn._mock_mock_calls[0].args[0]
assert warn_value == "You're operating within a distributed environment and have disabled the `drop_last`" \
" option. Please note that this configuration may lead to training interruptions if your system depends on distributed collectives." # noqa: E501