From 8fdcddb446088241b51da0cd7667c7ad452c0222 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Wed, 3 Mar 2021 11:41:59 +0000 Subject: [PATCH] asyncio: fix, tidy & update `gather` & tests --- tests/py37_asyncio.py | 24 +++++++++--------------- tqdm/asyncio.py | 40 ++++++++-------------------------------- 2 files changed, 17 insertions(+), 47 deletions(-) diff --git a/tests/py37_asyncio.py b/tests/py37_asyncio.py index 1e051178..18997ca7 100644 --- a/tests/py37_asyncio.py +++ b/tests/py37_asyncio.py @@ -115,20 +115,14 @@ async def test_as_completed(capsys, tol): raise -@mark.slow +async def double(i): + return i * 2 + + @mark.asyncio -@mark.parametrize("tol", [0.2 if platform.startswith("darwin") else 0.1]) -async def test_gather(capsys, tol): +async def test_gather(capsys): """Test asyncio gather""" - for retry in range(3): - t = time() - skew = time() - t - await gather([asyncio.sleep(0.01 * i) for i in range(30, 0, -1)]) - t = time() - t - 2 * skew - try: - assert 0.3 * (1 - tol) < t < 0.3 * (1 + tol), t - _, err = capsys.readouterr() - assert '30/30' in err - except AssertionError: - if retry == 2: - raise + res = await gather(list(map(double, range(30)))) + _, err = capsys.readouterr() + assert '30/30' in err + assert res == list(range(0, 30 * 2, 2)) diff --git a/tqdm/asyncio.py b/tqdm/asyncio.py index 5d25749a..a61d28b3 100644 --- a/tqdm/asyncio.py +++ b/tqdm/asyncio.py @@ -8,15 +8,12 @@ Usage: ... ... """ import asyncio -from typing import Awaitable, List, TypeVar from .std import tqdm as std_tqdm __author__ = {"github.com/": ["casperdcl"]} __all__ = ['tqdm_asyncio', 'tarange', 'tqdm', 'trange'] -T = TypeVar("T") - class tqdm_asyncio(std_tqdm): """ @@ -67,38 +64,17 @@ class tqdm_asyncio(std_tqdm): total=total, **tqdm_kwargs) @classmethod - async def gather( - cls, - fs: List[Awaitable[T]], - *, - loop=None, - timeout=None, - total=None, - **tqdm_kwargs - ) -> List[T]: + async def gather(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs): """ - Re-creating the functionality of asyncio.gather, giving a progress bar like - tqdm.as_completed(), but returning the results in original order. + Wrapper for `asyncio.gather`. """ - async def wrap_awaitable(number: int, awaitable: Awaitable[T]): - return number, await awaitable - if total is None: - total = len(fs) + async def wrap_awaitable(i, f): + return i, await f - numbered_awaitables = [wrap_awaitable(idx, fs[idx]) for idx in range(len(fs))] - - numbered_results = [ - await f for f in cls.as_completed( - numbered_awaitables, - total=total, - loop=loop, - timeout=timeout, - **tqdm_kwargs - ) - ] - - results = [result_tuple[1] for result_tuple in sorted(numbered_results)] - return results + ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)] + res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout, + total=total, **tqdm_kwargs)] + return [i for _, i in sorted(res)] def tarange(*args, **kwargs):