diff --git a/aioitertools/builtins.py b/aioitertools/builtins.py index 17068e1..ffd1de9 100644 --- a/aioitertools/builtins.py +++ b/aioitertools/builtins.py @@ -427,11 +427,17 @@ async def zip(*itrs: AnyIterable[Any]) -> AsyncIterator[Tuple[Any, ...]]: """ its: List[AsyncIterator[Any]] = [iter(itr) for itr in itrs] + ok = True - while True: + while ok: values = await asyncio.gather( *[it.__anext__() for it in its], return_exceptions=True ) - if builtins.any(isinstance(v, AnyStop) for v in values): - break - yield builtins.tuple(values) + for v in values: + if isinstance(v, BaseException): + if isinstance(v, AnyStop): + ok = False + break + raise v + if ok: + yield builtins.tuple(values) diff --git a/aioitertools/tests/builtins.py b/aioitertools/tests/builtins.py index cb94741..86598a7 100644 --- a/aioitertools/tests/builtins.py +++ b/aioitertools/tests/builtins.py @@ -370,3 +370,20 @@ class BuiltinsTest(TestCase): result = await ait.list(ait.zip(short, long)) expected = [("a", 0), ("b", 1), ("c", 2)] self.assertListEqual(expected, result) + + @async_test + async def test_zip_exception(self): + async def raise_after(x: int): + for i in range(x): + yield i + assert False + + short = raise_after(2) + long = ["a", "b", "c"] + + gen = ait.zip(short, long) + self.assertEqual((0, "a"), await ait.next(gen)) + self.assertEqual((1, "b"), await ait.next(gen)) + + with self.assertRaises(AssertionError): + await ait.next(gen)