From ad505f36f3206accd67523d0872d9f96ed309fff Mon Sep 17 00:00:00 2001 From: Stanley Kudrow <63244186+stankudrow@users.noreply.github.com> Date: Sun, 1 Sep 2024 10:18:39 +0300 Subject: [PATCH] Add `itertools.batched` v3.13 function (#177) * Implement itertools.batched * Match upstream args/example, add docs, better error tests --------- Co-authored-by: Stanley Kudrow Co-authored-by: Amethyst Reese --- aioitertools/__init__.py | 1 + aioitertools/itertools.py | 30 +++++++++++++++++++++++++++--- aioitertools/tests/itertools.py | 21 +++++++++++++++++++++ 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/aioitertools/__init__.py b/aioitertools/__init__.py index 2933601..1816bd8 100644 --- a/aioitertools/__init__.py +++ b/aioitertools/__init__.py @@ -25,6 +25,7 @@ from .builtins import ( ) from .itertools import ( accumulate, + batched, chain, combinations, combinations_with_replacement, diff --git a/aioitertools/itertools.py b/aioitertools/itertools.py index 7332811..b43e483 100644 --- a/aioitertools/itertools.py +++ b/aioitertools/itertools.py @@ -19,7 +19,7 @@ import itertools import operator from typing import Any, AsyncIterator, List, Optional, overload, Tuple -from .builtins import enumerate, iter, list, next, zip +from .builtins import enumerate, iter, list, next, tuple, zip from .helpers import maybe_await from .types import ( Accumulator, @@ -66,6 +66,30 @@ async def accumulate( yield total +async def batched( + iterable: AnyIterable[T], + n: int, + *, + strict: bool = False, +) -> AsyncIterator[Tuple[T, ...]]: + """ + Yield batches of values from the given iterable. The final batch may be shorter. + + Example:: + + async for batch in batched(range(15), 5): + ... # (0, 1, 2, 3, 4), (5, 6, 7, 8, 9), (10, 11, 12, 13, 14) + + """ + if n < 1: + raise ValueError("n must be at least one") + aiterator = iter(iterable) + while batch := await tuple(islice(aiterator, n)): + if strict and len(batch) != n: + raise ValueError("batched: incomplete batch") + yield batch + + class Chain: def __call__(self, *itrs: AnyIterable[T]) -> AsyncIterator[T]: """ @@ -517,7 +541,7 @@ def tee(itr: AnyIterable[T], n: int = 2) -> Tuple[AsyncIterator[T], ...]: break yield value - return tuple(gen(k, q) for k, q in builtins.enumerate(queues)) + return builtins.tuple(gen(k, q) for k, q in builtins.enumerate(queues)) async def zip_longest( @@ -556,4 +580,4 @@ async def zip_longest( raise value if finished >= itr_count: break - yield tuple(values) + yield builtins.tuple(values) diff --git a/aioitertools/tests/itertools.py b/aioitertools/tests/itertools.py index 76c1096..dc9a6ec 100644 --- a/aioitertools/tests/itertools.py +++ b/aioitertools/tests/itertools.py @@ -77,6 +77,27 @@ class ItertoolsTest(TestCase): self.assertEqual(values, []) + @async_test + async def test_batched(self): + test_matrix = [ + ([], 1, []), + ([1, 2, 3], 1, [(1,), (2,), (3,)]), + ([2, 3, 4], 2, [(2, 3), (4,)]), + ([5, 6], 3, [(5, 6)]), + (ait.iter([-2, -1, 0, 1, 2]), 2, [(-2, -1), (0, 1), (2,)]), + ] + for iterable, batch_size, answer in test_matrix: + result = [batch async for batch in ait.batched(iterable, batch_size)] + + self.assertEqual(result, answer) + + @async_test + async def test_batched_errors(self): + with self.assertRaisesRegex(ValueError, "n must be at least one"): + [batch async for batch in ait.batched([1], 0)] + with self.assertRaisesRegex(ValueError, "incomplete batch"): + [batch async for batch in ait.batched([1, 2, 3], 2, strict=True)] + @async_test async def test_chain_lists(self): it = ait.chain(slist, srange)