Add `itertools.batched` v3.13 function (#177)

* Implement itertools.batched

* Match upstream args/example, add docs, better error tests

---------

Co-authored-by: Stanley Kudrow <stankudrow@reply.no>
Co-authored-by: Amethyst Reese <amethyst@n7.gg>
This commit is contained in:
Stanley Kudrow 2024-09-01 10:18:39 +03:00 committed by GitHub
parent 538f1698bc
commit ad505f36f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 3 deletions

View File

@ -25,6 +25,7 @@ from .builtins import (
)
from .itertools import (
accumulate,
batched,
chain,
combinations,
combinations_with_replacement,

View File

@ -19,7 +19,7 @@ import itertools
import operator
from typing import Any, AsyncIterator, List, Optional, overload, Tuple
from .builtins import enumerate, iter, list, next, zip
from .builtins import enumerate, iter, list, next, tuple, zip
from .helpers import maybe_await
from .types import (
Accumulator,
@ -66,6 +66,30 @@ async def accumulate(
yield total
async def batched(
    iterable: AnyIterable[T],
    n: int,
    *,
    strict: bool = False,
) -> AsyncIterator[Tuple[T, ...]]:
    """
    Split the given iterable into consecutive tuples of up to ``n`` items.

    Batches are produced in order until the iterable is exhausted; an empty
    iterable produces no batches. By default a shorter final batch holding the
    leftover items is yielded as-is. With ``strict=True``, a batch whose length
    differs from ``n`` raises ``ValueError`` instead of being yielded.

    ``ValueError`` is also raised when ``n`` is less than one. Because this is
    an async generator, both errors surface once iteration begins, not when
    ``batched`` is called.

    Example::

        async for batch in batched(range(15), 5):
            ...  # (0, 1, 2, 3, 4), (5, 6, 7, 8, 9), (10, 11, 12, 13, 14)

    """
    if n < 1:
        raise ValueError("n must be at least one")

    # Wrap the input once so every slice below keeps consuming the same stream.
    source = iter(iterable)
    while True:
        chunk = await tuple(islice(source, n))
        if not chunk:
            # An empty slice means the source has nothing left to give.
            return
        if strict and len(chunk) != n:
            raise ValueError("batched: incomplete batch")
        yield chunk
class Chain:
def __call__(self, *itrs: AnyIterable[T]) -> AsyncIterator[T]:
"""
@ -517,7 +541,7 @@ def tee(itr: AnyIterable[T], n: int = 2) -> Tuple[AsyncIterator[T], ...]:
break
yield value
return tuple(gen(k, q) for k, q in builtins.enumerate(queues))
return builtins.tuple(gen(k, q) for k, q in builtins.enumerate(queues))
async def zip_longest(
@ -556,4 +580,4 @@ async def zip_longest(
raise value
if finished >= itr_count:
break
yield tuple(values)
yield builtins.tuple(values)

View File

@ -77,6 +77,27 @@ class ItertoolsTest(TestCase):
self.assertEqual(values, [])
@async_test
async def test_batched(self):
    """Check batched() output for empty, exact-fit, uneven and async inputs."""
    # Each case is (input iterable, batch size, expected list of batches).
    cases = (
        ([], 1, []),
        ([1, 2, 3], 1, [(1,), (2,), (3,)]),
        ([2, 3, 4], 2, [(2, 3), (4,)]),
        ([5, 6], 3, [(5, 6)]),
        (ait.iter([-2, -1, 0, 1, 2]), 2, [(-2, -1), (0, 1), (2,)]),
    )
    for source, size, expected in cases:
        collected = []
        async for chunk in ait.batched(source, size):
            collected.append(chunk)
        self.assertEqual(collected, expected)
@async_test
async def test_batched_errors(self):
    """Check that batched() rejects n < 1 and, in strict mode, a short batch."""
    # The generator must be iterated for either error to surface.
    with self.assertRaisesRegex(ValueError, "n must be at least one"):
        async for _ in ait.batched([1], 0):
            pass

    with self.assertRaisesRegex(ValueError, "incomplete batch"):
        async for _ in ait.batched([1, 2, 3], 2, strict=True):
            pass
@async_test
async def test_chain_lists(self):
it = ait.chain(slist, srange)