Make zip raise any exceptions from iterators
This commit is contained in:
parent
83d4343917
commit
6285944f42
|
@ -427,11 +427,17 @@ async def zip(*itrs: AnyIterable[Any]) -> AsyncIterator[Tuple[Any, ...]]:
|
|||
|
||||
"""
|
||||
its: List[AsyncIterator[Any]] = [iter(itr) for itr in itrs]
|
||||
ok = True
|
||||
|
||||
while True:
|
||||
while ok:
|
||||
values = await asyncio.gather(
|
||||
*[it.__anext__() for it in its], return_exceptions=True
|
||||
)
|
||||
if builtins.any(isinstance(v, AnyStop) for v in values):
|
||||
break
|
||||
yield builtins.tuple(values)
|
||||
for v in values:
|
||||
if isinstance(v, BaseException):
|
||||
if isinstance(v, AnyStop):
|
||||
ok = False
|
||||
break
|
||||
raise v
|
||||
if ok:
|
||||
yield builtins.tuple(values)
|
||||
|
|
|
@ -370,3 +370,20 @@ class BuiltinsTest(TestCase):
|
|||
result = await ait.list(ait.zip(short, long))
|
||||
expected = [("a", 0), ("b", 1), ("c", 2)]
|
||||
self.assertListEqual(expected, result)
|
||||
|
||||
@async_test
|
||||
async def test_zip_exception(self):
|
||||
async def raise_after(x: int):
|
||||
for i in range(x):
|
||||
yield i
|
||||
assert False
|
||||
|
||||
short = raise_after(2)
|
||||
long = ["a", "b", "c"]
|
||||
|
||||
gen = ait.zip(short, long)
|
||||
self.assertEqual((0, "a"), await ait.next(gen))
|
||||
self.assertEqual((1, "b"), await ait.next(gen))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
await ait.next(gen)
|
||||
|
|
Loading…
Reference in New Issue