Merge branch 'async-gather' into devel

- closes #1136
This commit is contained in:
Casper da Costa-Luis 2021-03-04 18:27:49 +00:00
commit ed048af509
No known key found for this signature in database
GPG Key ID: 986B408043AE090D
3 changed files with 31 additions and 4 deletions

View File

@ -10,6 +10,7 @@ from .tests_tqdm import StringIO, closing, mark
tqdm = partial(tqdm_asyncio, miniters=0, mininterval=0)
trange = partial(tarange, miniters=0, mininterval=0)
as_completed = partial(tqdm_asyncio.as_completed, miniters=0, mininterval=0)
gather = partial(tqdm_asyncio.gather, miniters=0, mininterval=0)
def count(start=0, step=1):
@ -112,3 +113,16 @@ async def test_as_completed(capsys, tol):
except AssertionError:
if retry == 2:
raise
async def double(i):
return i * 2
@mark.asyncio
async def test_gather(capsys):
"""Test asyncio gather"""
res = await gather(list(map(double, range(30))))
_, err = capsys.readouterr()
assert '30/30' in err
assert res == list(range(0, 30 * 2, 2))

View File

@ -17,7 +17,7 @@ __all__ = ['tqdm_asyncio', 'tarange', 'tqdm', 'trange']
class tqdm_asyncio(std_tqdm):
"""
Asynchronous-friendly version of tqdm (Python 3.5+).
Asynchronous-friendly version of tqdm (Python 3.6+).
"""
def __init__(self, iterable=None, *args, **kwargs):
super(tqdm_asyncio, self).__init__(iterable, *args, **kwargs)
@ -63,6 +63,19 @@ class tqdm_asyncio(std_tqdm):
yield from cls(asyncio.as_completed(fs, loop=loop, timeout=timeout),
total=total, **tqdm_kwargs)
@classmethod
async def gather(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
"""
Wrapper for `asyncio.gather`.
"""
async def wrap_awaitable(i, f):
return i, await f
ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
total=total, **tqdm_kwargs)]
return [i for _, i in sorted(res)]
def tarange(*args, **kwargs):
"""

View File

@ -4,7 +4,7 @@ Enables multiple commonly used features.
Method resolution order:
- `tqdm.autonotebook` without import warnings
- `tqdm.asyncio` on Python3.5+
- `tqdm.asyncio` on Python3.6+
- `tqdm.std` base class
Usage:
@ -22,10 +22,10 @@ with warnings.catch_warnings():
from .autonotebook import tqdm as notebook_tqdm
from .autonotebook import trange as notebook_trange
if sys.version_info[:2] < (3, 5):
if sys.version_info[:2] < (3, 6):
tqdm = notebook_tqdm
trange = notebook_trange
else: # Python3.5+
else: # Python3.6+
from .asyncio import tqdm as asyncio_tqdm
from .std import tqdm as std_tqdm