From bfc15b2ee6ae2536e41fdac26d74a4004b2570c1 Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Thu, 26 Mar 2020 22:21:50 +0000 Subject: [PATCH] Add more_itertools module with `take` and `chunked` --- README.md | 4 +- aioitertools/itertools.py | 2 +- aioitertools/more_itertools.py | 46 ++++++++++++++++++++++ aioitertools/tests/__init__.py | 1 + aioitertools/tests/more_itertools.py | 59 ++++++++++++++++++++++++++++ 5 files changed, 109 insertions(+), 3 deletions(-) create mode 100644 aioitertools/more_itertools.py create mode 100644 aioitertools/tests/more_itertools.py diff --git a/README.md b/README.md index e801e37..616a836 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,8 @@ async for value in islice(generator1(), 2, None, 2): ``` -See [builtins.py][] and [itertools.py][] for full documentation -of functions and abilities. +See [builtins.py][], [itertools.py][], and [more_itertools.py][] for full +documentation of functions and abilities. License diff --git a/aioitertools/itertools.py b/aioitertools/itertools.py index 5673433..75235f4 100644 --- a/aioitertools/itertools.py +++ b/aioitertools/itertools.py @@ -340,7 +340,7 @@ async def islice(itr: AnyIterable[T], *args: Optional[int]) -> AsyncIterator[T]: if not args: raise ValueError("must pass stop index") if len(args) == 1: - stop, = args + (stop,) = args elif len(args) == 2: start, stop = args # type: ignore elif len(args) == 3: diff --git a/aioitertools/more_itertools.py b/aioitertools/more_itertools.py new file mode 100644 index 0000000..5a062d3 --- /dev/null +++ b/aioitertools/more_itertools.py @@ -0,0 +1,46 @@ +# Copyright 2020 John Reese +# Licensed under the MIT license + +from typing import AsyncIterable, List, TypeVar + +from .builtins import iter +from .itertools import islice +from .types import AnyIterable + +T = TypeVar("T") + + +async def take(n: int, iterable: AnyIterable[T]) -> List[T]: + """ + Return the first n items of iterable as a list. + + If there are too few items in iterable, all of them are returned. + n needs to be at least 0. If it is 0, an empty list is returned. + + Example: + + first_two = await take(2, [1, 2, 3, 4, 5]) + + """ + if n < 0: + raise ValueError("take's first parameter can't be negative") + return [item async for item in islice(iterable, n)] + + +async def chunked(iterable: AnyIterable[T], n: int) -> AsyncIterable[List[T]]: + """ + Break iterable into chunks of length n. + + The last chunk will be shorter if the total number of items is not + divisible by n. + + Example: + + async for chunk in chunked([1, 2, 3, 4, 5], n=2): + ... # first iteration: chunk == [1, 2]; last one: chunk == [5] + """ + it = iter(iterable) + chunk = await take(n, it) + while chunk != []: + yield chunk + chunk = await take(n, it) diff --git a/aioitertools/tests/__init__.py b/aioitertools/tests/__init__.py index 75dbe82..1089a10 100644 --- a/aioitertools/tests/__init__.py +++ b/aioitertools/tests/__init__.py @@ -5,3 +5,4 @@ from .asyncio import AsyncioTest from .builtins import BuiltinsTest from .helpers import HelpersTest from .itertools import ItertoolsTest +from .more_itertools import MoreItertoolsTest diff --git a/aioitertools/tests/more_itertools.py b/aioitertools/tests/more_itertools.py new file mode 100644 index 0000000..ad3b39e --- /dev/null +++ b/aioitertools/tests/more_itertools.py @@ -0,0 +1,59 @@ +# Copyright 2020 John Reese +# Licensed under the MIT license + +from typing import AsyncIterable +from unittest import TestCase + +import aioitertools.more_itertools as mit + +from .helpers import async_test + + +async def _gen() -> AsyncIterable[int]: + for i in range(5): + yield i + + +async def _empty() -> AsyncIterable[int]: + return + yield 0 # pylint: disable=unreachable + + +class MoreItertoolsTest(TestCase): + @async_test + async def test_take(self) -> None: + self.assertEqual(await mit.take(2, _gen()), [0, 1]) + self.assertEqual(await mit.take(2, range(5)), [0, 1]) + + @async_test + async def test_take_zero(self) -> None: + self.assertEqual(await mit.take(0, _gen()), []) + + @async_test + async def test_take_negative(self) -> None: + with self.assertRaises(ValueError): + await mit.take(-1, _gen()) + + @async_test + async def test_take_more_than_iterable(self) -> None: + self.assertEqual(await mit.take(10, _gen()), list(range(5))) + + @async_test + async def test_take_empty(self) -> None: + it = _gen() + self.assertEqual(len(await mit.take(5, it)), 5) + self.assertEqual(await mit.take(1, it), []) + self.assertEqual(await mit.take(1, _empty()), []) + + @async_test + async def test_chunked(self) -> None: + self.assertEqual( + [chunk async for chunk in mit.chunked(_gen(), 2)], [[0, 1], [2, 3], [4]] + ) + self.assertEqual( + [chunk async for chunk in mit.chunked(range(5), 2)], [[0, 1], [2, 3], [4]] + ) + + @async_test + async def test_chunked_empty(self) -> None: + self.assertEqual([], [chunk async for chunk in mit.chunked(_empty(), 2)])