From b024e7a73bb33bd84a8823b1712114c1f2900316 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 15 Feb 2024 17:11:45 +0000 Subject: [PATCH] Better default for drop_last in a distributed setting (#19478) Co-authored-by: Rohit Gupta Co-authored-by: thomas --- src/lightning/data/streaming/dataset.py | 23 +++++- src/lightning/data/streaming/serializers.py | 4 +- src/lightning/data/streaming/shuffle.py | 22 +++--- src/lightning/data/utilities/env.py | 2 +- tests/tests_data/streaming/test_dataset.py | 81 +++++++++++++++------ 5 files changed, 96 insertions(+), 36 deletions(-) diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index da0028184e..16159a3a6f 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -13,6 +13,7 @@ import hashlib import os +from logging import Logger from time import time from typing import Any, Dict, List, Optional, Tuple, Union @@ -31,6 +32,8 @@ from lightning.data.streaming.serializers import Serializer from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle from lightning.data.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv +logger = Logger(__name__) + class StreamingDataset(IterableDataset): """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.""" @@ -40,7 +43,7 @@ class StreamingDataset(IterableDataset): input_dir: Union[str, "Dir"], item_loader: Optional[BaseItemLoader] = None, shuffle: bool = False, - drop_last: bool = False, + drop_last: Optional[bool] = None, seed: int = 42, serializers: Optional[Dict[str, Serializer]] = None, max_cache_size: Union[int, str] = "100GB", @@ -53,6 +56,8 @@ class StreamingDataset(IterableDataset): shuffle: Whether to shuffle the data. drop_last: If `True`, drops the last items to ensure that all processes/workers return the same amount of data. + The argument `drop_last` is set to `True` in a distributed setting + and `False` otherwise. seed: Random seed for shuffling. serializers: The serializers used to serialize and deserialize the chunks. max_cache_size: The maximum cache size used by the StreamingDataset. @@ -68,12 +73,24 @@ class StreamingDataset(IterableDataset): self.item_loader = item_loader self.shuffle: bool = shuffle - self.drop_last = drop_last + self.distributed_env = _DistributedEnv.detect() + + if self.distributed_env.world_size > 1: + if drop_last is False: + logger.warn( + "You're operating within a distributed environment and have disabled the `drop_last` option. " + "Please note that this configuration may lead to training interruptions if your system depends " + "on distributed collectives." 
+ ) + else: + drop_last = True + + self.drop_last = drop_last or False + self.seed = seed self.max_cache_size = max_cache_size self.cache: Optional[Cache] = None - self.distributed_env = _DistributedEnv.detect() self.worker_env: Optional[_WorkerEnv] = None self.worker_chunks: List[int] = [] self.worker_intervals: List[List[int]] = [] diff --git a/src/lightning/data/streaming/serializers.py b/src/lightning/data/streaming/serializers.py index b7ee18fe78..57c9bb095e 100644 --- a/src/lightning/data/streaming/serializers.py +++ b/src/lightning/data/streaming/serializers.py @@ -88,7 +88,7 @@ class PILSerializer(Serializer): return Image.frombytes(mode, size, raw) # pyright: ignore def can_serialize(self, item: Any) -> bool: - return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile) + return bool(_PIL_AVAILABLE) and isinstance(item, Image.Image) and not isinstance(item, JpegImageFile) class JPEGSerializer(Serializer): @@ -133,7 +133,7 @@ class JPEGSerializer(Serializer): return img def can_serialize(self, item: Any) -> bool: - return isinstance(item, JpegImageFile) + return bool(_PIL_AVAILABLE) and isinstance(item, JpegImageFile) class BytesSerializer(Serializer): diff --git a/src/lightning/data/streaming/shuffle.py b/src/lightning/data/streaming/shuffle.py index b0a48bd728..85ae67dec3 100644 --- a/src/lightning/data/streaming/shuffle.py +++ b/src/lightning/data/streaming/shuffle.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from functools import lru_cache -from typing import Any, List, Tuple +from typing import Any, List import numpy as np @@ -38,8 +38,10 @@ class Shuffle(ABC): items_per_process = [ sum((interval[-1] - interval[0]) for interval in intervals) for intervals in intervals_per_ranks ] - min_items_per_process = min(items_per_process) - return min_items_per_process + # Validate each processes gets the exact number of elements + if len(items_per_process) > 1: + assert all(items_per_process[0] == items_to_process for items_to_process in items_per_process[:1]) + return items_per_process[0] return sum((interval[-1] - interval[0]) for interval in intervals_per_ranks[distributed_env.global_rank]) @@ -58,16 +60,18 @@ class NoShuffle(Shuffle): @lru_cache(maxsize=10) def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any: + # 1. Get the intervals chunk_intervals = self.cache.get_chunk_intervals() - chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)] - intervals_per_ranks: List[List[Tuple]] = [[] for _ in range(distributed_env.world_size)] - for chunk_index, chunk_interval in enumerate(chunk_intervals): - replica_index = chunk_index % distributed_env.world_size - chunks_per_ranks[replica_index].append(chunk_index) - intervals_per_ranks[replica_index].append(chunk_interval) + indexes = range(len(chunk_intervals)) + + # 2. 
Compute the items budget of each rank + chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + distributed_env, indexes, chunk_intervals, self.drop_last + ) return chunks_per_ranks, intervals_per_ranks + def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]: return array.tolist() diff --git a/src/lightning/data/utilities/env.py b/src/lightning/data/utilities/env.py index c9406963d9..027346d216 100644 --- a/src/lightning/data/utilities/env.py +++ b/src/lightning/data/utilities/env.py @@ -45,7 +45,7 @@ class _DistributedEnv: # validate the world size is divisble by the number of GPUs assert world_size % torch.cuda.device_count() == 0 - return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes) + return cls(world_size=world_size, global_rank=global_rank, num_nodes=max(1, num_nodes)) def __repr__(self) -> str: return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)" diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index e48db3fab9..ddf349c40e 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -22,6 +22,7 @@ import torch from lightning import seed_everything from lightning.data.processing import functions from lightning.data.streaming import Cache +from lightning.data.streaming import dataset as dataset_module from lightning.data.streaming.dataloader import StreamingDataLoader from lightning.data.streaming.dataset import ( _INDEX_FILENAME, @@ -100,31 +101,46 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir): assert len(dataset) == 101 dataset.distributed_env = _DistributedEnv(2, 0, 1) + assert len(dataset) == 50 + + dataset.distributed_env = _DistributedEnv(2, 1, 1) assert len(dataset) == 50 + int(not drop_last) + dataset_iter = iter(dataset) assert len(dataset_iter) == 50 + int(not drop_last) + + dataset.distributed_env = _DistributedEnv(2, 0, 1) + process_1_1 = list(dataset_iter) - assert len(process_1_1) == 50 + int(not drop_last) + + assert len(process_1_1) == 50 assert process_1_1[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] dataset_iter = iter(dataset) - assert len(dataset_iter) == 50 + int(not drop_last) + + assert len(dataset_iter) == 50 process_1_2 = list(dataset_iter) assert process_1_2[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - assert len(process_1_2) == 50 + int(not drop_last) + + assert len(process_1_2) == 50 dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False, drop_last=drop_last) dataset.distributed_env = _DistributedEnv(2, 1, 1) - assert len(dataset) == 50 - dataset_iter = iter(dataset) - process_2_1 = list(dataset_iter) - assert process_2_1[:10] == [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] - assert len(process_2_1) == 50 - dataset_iter = iter(dataset) - assert len(dataset_iter) == 50 - process_2_2 = list(dataset_iter) - assert process_2_2[:10] == [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] - assert len(process_2_2) == 50 + assert len(dataset) == 50 + int(not drop_last) + dataset_iter = iter(dataset) + + process_2_1 = list(dataset_iter) + assert process_2_1[:10] == [50, 51, 52, 53, 54, 55, 56, 57, 58, 59] + + assert len(process_2_1) == 50 + int(not drop_last) + dataset_iter = iter(dataset) + + assert len(dataset_iter) == 50 + int(not drop_last) + process_2_2 = list(dataset_iter) + + assert process_2_2[:10] == [50, 51, 52, 53, 54, 55, 56, 57, 58, 59] + + assert len(process_2_2) == 50 + int(not drop_last) _, 
intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_ranks( dataset.distributed_env, dataset.current_epoch @@ -503,11 +519,11 @@ def test_dataset_for_text_tokens_distributed_num_workers(tmpdir): assert len(dataset) == 20 dataset.distributed_env = _DistributedEnv(2, 0, 1) - dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) + dataloader = DataLoader(dataset, batch_size=2, shuffle=False) - assert len(dataloader) == 6 + assert len(dataloader) == 5 - expected = [[0, 10], [80, 90], [20, 30], [100, 110], [160, 170], [180, 190]] + expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]] for batch_idx, batch in enumerate(dataloader): assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] @@ -515,9 +531,9 @@ def test_dataset_for_text_tokens_distributed_num_workers(tmpdir): dataset.distributed_env = _DistributedEnv(2, 1, 1) dataloader = DataLoader(dataset, batch_size=2, shuffle=False) - assert len(dataloader) == 4 + assert len(dataloader) == 5 - expected = [[40, 50], [60, 70], [120, 130], [140, 150]] + expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]] for batch_idx, batch in enumerate(dataloader): assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] @@ -570,7 +586,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk assert len(dataloader) == 5 - expected = [[0, 10], [40, 50], [80, 90], [120, 130], [160, 170]] + expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]] for batch_idx, batch in enumerate(dataloader): assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] @@ -580,7 +596,7 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk assert len(dataloader) == 5 - expected = [[20, 30], [60, 70], [100, 110], [140, 150], [180, 190]] + expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]] for batch_idx, batch in enumerate(dataloader): assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] @@ -760,7 +776,7 @@ def test_dataset_valid_state(tmpdir, monkeypatch): cache.merge() dataset = EmulateS3StreamingDataset( - input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=False + input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=False, drop_last=False, ) dataloader = DataLoader(dataset, num_workers=1, batch_size=2) dataloader_iter = iter(dataloader) @@ -865,3 +881,26 @@ def test_replay_chunks_sampling(): assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1}) assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3}) assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2}) + + +def test_dataset_distributed_drop_last(tmpdir, monkeypatch): + + class _DistributedEnvMock(): + + def detect(cls): + return _DistributedEnv(2, 0, 1) + + logger_mock = mock.MagicMock() + + monkeypatch.setattr(dataset_module, "_DistributedEnv", _DistributedEnvMock()) + monkeypatch.setattr(dataset_module, "logger", logger_mock) + + dataset = StreamingDataset(str(tmpdir), drop_last=None) + assert dataset.drop_last + + dataset = StreamingDataset(str(tmpdir), drop_last=False) + assert not dataset.drop_last + + warn_value = logger_mock.warn._mock_mock_calls[0].args[0] + assert warn_value == "You're operating within a distributed environment and have disabled the `drop_last`" \ + " option. 
Please note that this configuration may lead to training interruptions if your system depends on distributed collectives." # noqa: E501
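
For readers skimming the diff, the behaviour introduced in `StreamingDataset.__init__` reduces to the small resolution rule sketched below. This is an illustrative, self-contained sketch rather than code from the patch: `resolve_drop_last` is a hypothetical helper, and `world_size` stands in for `_DistributedEnv.detect().world_size`.

    import logging

    logger = logging.getLogger(__name__)


    def resolve_drop_last(drop_last, world_size):
        # Mirrors the patched __init__: `None` (the new default) resolves to True
        # on multiple ranks and to False on a single process. An explicit False is
        # honoured but logged, because ranks yielding different item counts can
        # stall distributed collectives.
        if world_size > 1:
            if drop_last is False:
                logger.warning(
                    "You're operating within a distributed environment and have disabled the `drop_last` option. "
                    "Please note that this configuration may lead to training interruptions if your system depends "
                    "on distributed collectives."
                )
            else:
                drop_last = True
        return drop_last or False


    assert resolve_drop_last(None, world_size=1) is False   # single process: keep every item
    assert resolve_drop_last(None, world_size=2) is True    # distributed: drop the remainder by default
    assert resolve_drop_last(False, world_size=2) is False  # explicit opt-out is kept, with a warning

With the 101-item dataset used in `test_streaming_dataset_distributed_no_shuffle`, the new default therefore makes both ranks report `len(dataset) == 50`, whereas opting out leaves rank 1 with 51 items (`50 + int(not drop_last)`), which is exactly the uneven situation the warning refers to.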