mirror of https://github.com/tqdm/tqdm.git
asyncio: fix, tidy & update `gather` & tests
This commit is contained in:
parent
82e0851f63
commit
8fdcddb446
|
@ -115,20 +115,14 @@ async def test_as_completed(capsys, tol):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@mark.slow
|
async def double(i):
|
||||||
|
return i * 2
|
||||||
|
|
||||||
|
|
||||||
@mark.asyncio
|
@mark.asyncio
|
||||||
@mark.parametrize("tol", [0.2 if platform.startswith("darwin") else 0.1])
|
async def test_gather(capsys):
|
||||||
async def test_gather(capsys, tol):
|
|
||||||
"""Test asyncio gather"""
|
"""Test asyncio gather"""
|
||||||
for retry in range(3):
|
res = await gather(list(map(double, range(30))))
|
||||||
t = time()
|
_, err = capsys.readouterr()
|
||||||
skew = time() - t
|
assert '30/30' in err
|
||||||
await gather([asyncio.sleep(0.01 * i) for i in range(30, 0, -1)])
|
assert res == list(range(0, 30 * 2, 2))
|
||||||
t = time() - t - 2 * skew
|
|
||||||
try:
|
|
||||||
assert 0.3 * (1 - tol) < t < 0.3 * (1 + tol), t
|
|
||||||
_, err = capsys.readouterr()
|
|
||||||
assert '30/30' in err
|
|
||||||
except AssertionError:
|
|
||||||
if retry == 2:
|
|
||||||
raise
|
|
||||||
|
|
|
@ -8,15 +8,12 @@ Usage:
|
||||||
... ...
|
... ...
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Awaitable, List, TypeVar
|
|
||||||
|
|
||||||
from .std import tqdm as std_tqdm
|
from .std import tqdm as std_tqdm
|
||||||
|
|
||||||
__author__ = {"github.com/": ["casperdcl"]}
|
__author__ = {"github.com/": ["casperdcl"]}
|
||||||
__all__ = ['tqdm_asyncio', 'tarange', 'tqdm', 'trange']
|
__all__ = ['tqdm_asyncio', 'tarange', 'tqdm', 'trange']
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class tqdm_asyncio(std_tqdm):
|
class tqdm_asyncio(std_tqdm):
|
||||||
"""
|
"""
|
||||||
|
@ -67,38 +64,17 @@ class tqdm_asyncio(std_tqdm):
|
||||||
total=total, **tqdm_kwargs)
|
total=total, **tqdm_kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def gather(
|
async def gather(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
|
||||||
cls,
|
|
||||||
fs: List[Awaitable[T]],
|
|
||||||
*,
|
|
||||||
loop=None,
|
|
||||||
timeout=None,
|
|
||||||
total=None,
|
|
||||||
**tqdm_kwargs
|
|
||||||
) -> List[T]:
|
|
||||||
"""
|
"""
|
||||||
Re-creating the functionality of asyncio.gather, giving a progress bar like
|
Wrapper for `asyncio.gather`.
|
||||||
tqdm.as_completed(), but returning the results in original order.
|
|
||||||
"""
|
"""
|
||||||
async def wrap_awaitable(number: int, awaitable: Awaitable[T]):
|
async def wrap_awaitable(i, f):
|
||||||
return number, await awaitable
|
return i, await f
|
||||||
if total is None:
|
|
||||||
total = len(fs)
|
|
||||||
|
|
||||||
numbered_awaitables = [wrap_awaitable(idx, fs[idx]) for idx in range(len(fs))]
|
ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
|
||||||
|
res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
|
||||||
numbered_results = [
|
total=total, **tqdm_kwargs)]
|
||||||
await f for f in cls.as_completed(
|
return [i for _, i in sorted(res)]
|
||||||
numbered_awaitables,
|
|
||||||
total=total,
|
|
||||||
loop=loop,
|
|
||||||
timeout=timeout,
|
|
||||||
**tqdm_kwargs
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
results = [result_tuple[1] for result_tuple in sorted(numbered_results)]
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def tarange(*args, **kwargs):
|
def tarange(*args, **kwargs):
|
||||||
|
|
Loading…
Reference in New Issue