asyncio: fix, tidy & update `gather` & tests

This commit is contained in:
Casper da Costa-Luis 2021-03-03 11:41:59 +00:00
parent 82e0851f63
commit 8fdcddb446
No known key found for this signature in database
GPG Key ID: 986B408043AE090D
2 changed files with 17 additions and 47 deletions

View File

@ -115,20 +115,14 @@ async def test_as_completed(capsys, tol):
raise raise
@mark.slow async def double(i):
return i * 2
@mark.asyncio @mark.asyncio
@mark.parametrize("tol", [0.2 if platform.startswith("darwin") else 0.1]) async def test_gather(capsys):
async def test_gather(capsys, tol):
"""Test asyncio gather""" """Test asyncio gather"""
for retry in range(3): res = await gather(list(map(double, range(30))))
t = time() _, err = capsys.readouterr()
skew = time() - t assert '30/30' in err
await gather([asyncio.sleep(0.01 * i) for i in range(30, 0, -1)]) assert res == list(range(0, 30 * 2, 2))
t = time() - t - 2 * skew
try:
assert 0.3 * (1 - tol) < t < 0.3 * (1 + tol), t
_, err = capsys.readouterr()
assert '30/30' in err
except AssertionError:
if retry == 2:
raise

View File

@ -8,15 +8,12 @@ Usage:
... ... ... ...
""" """
import asyncio import asyncio
from typing import Awaitable, List, TypeVar
from .std import tqdm as std_tqdm from .std import tqdm as std_tqdm
__author__ = {"github.com/": ["casperdcl"]} __author__ = {"github.com/": ["casperdcl"]}
__all__ = ['tqdm_asyncio', 'tarange', 'tqdm', 'trange'] __all__ = ['tqdm_asyncio', 'tarange', 'tqdm', 'trange']
T = TypeVar("T")
class tqdm_asyncio(std_tqdm): class tqdm_asyncio(std_tqdm):
""" """
@ -67,38 +64,17 @@ class tqdm_asyncio(std_tqdm):
total=total, **tqdm_kwargs) total=total, **tqdm_kwargs)
@classmethod @classmethod
async def gather( async def gather(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
cls,
fs: List[Awaitable[T]],
*,
loop=None,
timeout=None,
total=None,
**tqdm_kwargs
) -> List[T]:
""" """
Re-creating the functionality of asyncio.gather, giving a progress bar like Wrapper for `asyncio.gather`.
tqdm.as_completed(), but returning the results in original order.
""" """
async def wrap_awaitable(number: int, awaitable: Awaitable[T]): async def wrap_awaitable(i, f):
return number, await awaitable return i, await f
if total is None:
total = len(fs)
numbered_awaitables = [wrap_awaitable(idx, fs[idx]) for idx in range(len(fs))] ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
numbered_results = [ total=total, **tqdm_kwargs)]
await f for f in cls.as_completed( return [i for _, i in sorted(res)]
numbered_awaitables,
total=total,
loop=loop,
timeout=timeout,
**tqdm_kwargs
)
]
results = [result_tuple[1] for result_tuple in sorted(numbered_results)]
return results
def tarange(*args, **kwargs): def tarange(*args, **kwargs):