From 6285944f42395ffdca96aa388d2db1e4f545c7bb Mon Sep 17 00:00:00 2001 From: Amethyst Reese Date: Mon, 5 Feb 2024 19:54:51 -0800 Subject: [PATCH] Make zip raise any exceptions from iterators --- aioitertools/builtins.py | 14 ++++++++++---- aioitertools/tests/builtins.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/aioitertools/builtins.py b/aioitertools/builtins.py index 17068e1..ffd1de9 100644 --- a/aioitertools/builtins.py +++ b/aioitertools/builtins.py @@ -427,11 +427,17 @@ async def zip(*itrs: AnyIterable[Any]) -> AsyncIterator[Tuple[Any, ...]]: """ its: List[AsyncIterator[Any]] = [iter(itr) for itr in itrs] + ok = True - while True: + while ok: values = await asyncio.gather( *[it.__anext__() for it in its], return_exceptions=True ) - if builtins.any(isinstance(v, AnyStop) for v in values): - break - yield builtins.tuple(values) + for v in values: + if isinstance(v, BaseException): + if isinstance(v, AnyStop): + ok = False + break + raise v + if ok: + yield builtins.tuple(values) diff --git a/aioitertools/tests/builtins.py b/aioitertools/tests/builtins.py index cb94741..86598a7 100644 --- a/aioitertools/tests/builtins.py +++ b/aioitertools/tests/builtins.py @@ -370,3 +370,20 @@ class BuiltinsTest(TestCase): result = await ait.list(ait.zip(short, long)) expected = [("a", 0), ("b", 1), ("c", 2)] self.assertListEqual(expected, result) + + @async_test + async def test_zip_exception(self): + async def raise_after(x: int): + for i in range(x): + yield i + assert False + + short = raise_after(2) + long = ["a", "b", "c"] + + gen = ait.zip(short, long) + self.assertEqual((0, "a"), await ait.next(gen)) + self.assertEqual((1, "b"), await ait.next(gen)) + + with self.assertRaises(AssertionError): + await ait.next(gen)