diff --git a/aioitertools/more_itertools.py b/aioitertools/more_itertools.py index 9a79357..69f61a7 100644 --- a/aioitertools/more_itertools.py +++ b/aioitertools/more_itertools.py @@ -1,11 +1,15 @@ # Copyright 2020 John Reese # Licensed under the MIT license -from typing import AsyncIterable, List, TypeVar +import asyncio +from typing import AsyncIterable, List, Tuple, TypeVar + +from aioitertools.helpers import maybe_await from .builtins import iter from .itertools import islice -from .types import AnyIterable +from .types import AnyIterable, Predicate + T = TypeVar("T") @@ -44,3 +48,47 @@ async def chunked(iterable: AnyIterable[T], n: int) -> AsyncIterable[List[T]]: while chunk != []: yield chunk chunk = await take(n, it) + + +async def before_and_after( + predicate: Predicate[T], iterable: AnyIterable[T] +) -> Tuple[AsyncIterable[T], AsyncIterable[T]]: + """ + A variant of :func:`aioitertools.takewhile` that allows complete access to the + remainder of the iterator. + + >>> it = iter('ABCdEfGhI') + >>> all_upper, remainder = await before_and_after(str.isupper, it) + >>> ''.join([char async for char in all_upper]) + 'ABC' + >>> ''.join([char async for char in remainder]) + 'dEfGhI' + + Note that the first iterator must be fully consumed before the second + iterator can generate valid results. + """ + + it = iter(iterable) + + transition = asyncio.get_running_loop().create_future() + + async def true_iterator(): + async for elem in it: + if await maybe_await(predicate(elem)): + yield elem + else: + transition.set_result(elem) + return + + transition.set_exception(StopAsyncIteration) + + async def remainder_iterator(): + try: + yield (await transition) + except StopAsyncIteration: + return + + async for elm in it: + yield elm + + return true_iterator(), remainder_iterator() diff --git a/aioitertools/tests/more_itertools.py b/aioitertools/tests/more_itertools.py index 3a7a3f6..0d4334e 100644 --- a/aioitertools/tests/more_itertools.py +++ b/aioitertools/tests/more_itertools.py @@ -56,3 +56,41 @@ class MoreItertoolsTest(TestCase): @async_test async def test_chunked_empty(self) -> None: self.assertEqual([], [chunk async for chunk in mit.chunked(_empty(), 2)]) + + @async_test + async def test_before_and_after_split(self) -> None: + it = _gen() + before, after = await mit.before_and_after(lambda i: i <= 2, it) + self.assertEqual([elm async for elm in before], [0, 1, 2]) + self.assertEqual([elm async for elm in after], [3, 4]) + + @async_test + async def test_before_and_after_before_only(self) -> None: + it = _gen() + before, after = await mit.before_and_after(lambda i: True, it) + self.assertEqual([elm async for elm in before], [0, 1, 2, 3, 4]) + self.assertEqual([elm async for elm in after], []) + + @async_test + async def test_before_and_after_after_only(self) -> None: + it = _gen() + before, after = await mit.before_and_after(lambda i: False, it) + self.assertEqual([elm async for elm in before], []) + self.assertEqual([elm async for elm in after], [0, 1, 2, 3, 4]) + + @async_test + async def test_before_and_after_async_predicate(self) -> None: + async def predicate(elm: int) -> bool: + return elm <= 2 + + it = _gen() + before, after = await mit.before_and_after(predicate, it) + self.assertEqual([elm async for elm in before], [0, 1, 2]) + self.assertEqual([elm async for elm in after], [3, 4]) + + @async_test + async def test_before_and_after_empty(self) -> None: + it = _empty() + before, after = await mit.before_and_after(lambda i: True, it) + self.assertEqual([elm async for elm in before], []) + self.assertEqual([elm async for elm in after], [])