Add `itertools.batched` v3.13 function (#177)
* Implement itertools.batched * Match upstream args/example, add docs, better error tests --------- Co-authored-by: Stanley Kudrow <stankudrow@reply.no> Co-authored-by: Amethyst Reese <amethyst@n7.gg>
This commit is contained in:
parent
538f1698bc
commit
ad505f36f3
|
@ -25,6 +25,7 @@ from .builtins import (
|
|||
)
|
||||
from .itertools import (
|
||||
accumulate,
|
||||
batched,
|
||||
chain,
|
||||
combinations,
|
||||
combinations_with_replacement,
|
||||
|
|
|
@ -19,7 +19,7 @@ import itertools
|
|||
import operator
|
||||
from typing import Any, AsyncIterator, List, Optional, overload, Tuple
|
||||
|
||||
from .builtins import enumerate, iter, list, next, zip
|
||||
from .builtins import enumerate, iter, list, next, tuple, zip
|
||||
from .helpers import maybe_await
|
||||
from .types import (
|
||||
Accumulator,
|
||||
|
@ -66,6 +66,30 @@ async def accumulate(
|
|||
yield total
|
||||
|
||||
|
||||
async def batched(
|
||||
iterable: AnyIterable[T],
|
||||
n: int,
|
||||
*,
|
||||
strict: bool = False,
|
||||
) -> AsyncIterator[Tuple[T, ...]]:
|
||||
"""
|
||||
Yield batches of values from the given iterable. The final batch may be shorter.
|
||||
|
||||
Example::
|
||||
|
||||
async for batch in batched(range(15), 5):
|
||||
... # (0, 1, 2, 3, 4), (5, 6, 7, 8, 9), (10, 11, 12, 13, 14)
|
||||
|
||||
"""
|
||||
if n < 1:
|
||||
raise ValueError("n must be at least one")
|
||||
aiterator = iter(iterable)
|
||||
while batch := await tuple(islice(aiterator, n)):
|
||||
if strict and len(batch) != n:
|
||||
raise ValueError("batched: incomplete batch")
|
||||
yield batch
|
||||
|
||||
|
||||
class Chain:
|
||||
def __call__(self, *itrs: AnyIterable[T]) -> AsyncIterator[T]:
|
||||
"""
|
||||
|
@ -517,7 +541,7 @@ def tee(itr: AnyIterable[T], n: int = 2) -> Tuple[AsyncIterator[T], ...]:
|
|||
break
|
||||
yield value
|
||||
|
||||
return tuple(gen(k, q) for k, q in builtins.enumerate(queues))
|
||||
return builtins.tuple(gen(k, q) for k, q in builtins.enumerate(queues))
|
||||
|
||||
|
||||
async def zip_longest(
|
||||
|
@ -556,4 +580,4 @@ async def zip_longest(
|
|||
raise value
|
||||
if finished >= itr_count:
|
||||
break
|
||||
yield tuple(values)
|
||||
yield builtins.tuple(values)
|
||||
|
|
|
@ -77,6 +77,27 @@ class ItertoolsTest(TestCase):
|
|||
|
||||
self.assertEqual(values, [])
|
||||
|
||||
@async_test
|
||||
async def test_batched(self):
|
||||
test_matrix = [
|
||||
([], 1, []),
|
||||
([1, 2, 3], 1, [(1,), (2,), (3,)]),
|
||||
([2, 3, 4], 2, [(2, 3), (4,)]),
|
||||
([5, 6], 3, [(5, 6)]),
|
||||
(ait.iter([-2, -1, 0, 1, 2]), 2, [(-2, -1), (0, 1), (2,)]),
|
||||
]
|
||||
for iterable, batch_size, answer in test_matrix:
|
||||
result = [batch async for batch in ait.batched(iterable, batch_size)]
|
||||
|
||||
self.assertEqual(result, answer)
|
||||
|
||||
@async_test
|
||||
async def test_batched_errors(self):
|
||||
with self.assertRaisesRegex(ValueError, "n must be at least one"):
|
||||
[batch async for batch in ait.batched([1], 0)]
|
||||
with self.assertRaisesRegex(ValueError, "incomplete batch"):
|
||||
[batch async for batch in ait.batched([1, 2, 3], 2, strict=True)]
|
||||
|
||||
@async_test
|
||||
async def test_chain_lists(self):
|
||||
it = ait.chain(slist, srange)
|
||||
|
|
Loading…
Reference in New Issue