mirror of https://github.com/tqdm/tqdm.git
Fix that tqdm_class cannot be passed to `tmap` and `tzip`
This commit is contained in:
parent
bcce20f771
commit
ea92a0b5fb
|
@ -3,6 +3,9 @@ Tests for `tqdm.contrib`.
|
|||
"""
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from tqdm import tqdm
|
||||
from tqdm.contrib import tenumerate, tmap, tzip
|
||||
|
||||
from .tests_tqdm import StringIO, closing, importorskip
|
||||
|
@ -13,49 +16,56 @@ def incr(x):
|
|||
return x + 1
|
||||
|
||||
|
||||
def test_enumerate():
|
||||
@pytest.mark.parametrize("tqdm_kwargs", [dict(), dict(tqdm_class=tqdm)])
|
||||
def test_enumerate(tqdm_kwargs):
|
||||
"""Test contrib.tenumerate"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
assert list(tenumerate(a, file=our_file)) == list(enumerate(a))
|
||||
assert list(tenumerate(a, 42, file=our_file)) == list(enumerate(a, 42))
|
||||
assert list(tenumerate(a, file=our_file, **tqdm_kwargs)) == list(enumerate(a))
|
||||
assert list(tenumerate(a, 42, file=our_file, **tqdm_kwargs)) == list(
|
||||
enumerate(a, 42)
|
||||
)
|
||||
with closing(StringIO()) as our_file:
|
||||
_ = list(tenumerate((i for i in a), file=our_file))
|
||||
_ = list(tenumerate((i for i in a), file=our_file, **tqdm_kwargs))
|
||||
assert "100%" not in our_file.getvalue()
|
||||
with closing(StringIO()) as our_file:
|
||||
_ = list(tenumerate((i for i in a), file=our_file, total=len(a)))
|
||||
_ = list(tenumerate((i for i in a), file=our_file, total=len(a), **tqdm_kwargs))
|
||||
assert "100%" in our_file.getvalue()
|
||||
|
||||
|
||||
def test_enumerate_numpy():
|
||||
"""Test contrib.tenumerate(numpy.ndarray)"""
|
||||
np = importorskip('numpy')
|
||||
np = importorskip("numpy")
|
||||
with closing(StringIO()) as our_file:
|
||||
a = np.random.random((42, 7))
|
||||
assert list(tenumerate(a, file=our_file)) == list(np.ndenumerate(a))
|
||||
|
||||
|
||||
def test_zip():
|
||||
@pytest.mark.parametrize("tqdm_kwargs", [dict(), dict(tqdm_class=tqdm)])
|
||||
def test_zip(tqdm_kwargs):
|
||||
"""Test contrib.tzip"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
b = [i + 1 for i in a]
|
||||
if sys.version_info[:1] < (3,):
|
||||
assert tzip(a, b, file=our_file) == zip(a, b)
|
||||
assert tzip(a, b, file=our_file, **tqdm_kwargs) == zip(a, b)
|
||||
else:
|
||||
gen = tzip(a, b, file=our_file)
|
||||
gen = tzip(a, b, file=our_file, **tqdm_kwargs)
|
||||
assert gen != list(zip(a, b))
|
||||
assert list(gen) == list(zip(a, b))
|
||||
|
||||
|
||||
def test_map():
|
||||
@pytest.mark.parametrize("tqdm_kwargs", [dict(), dict(tqdm_class=tqdm)])
|
||||
def test_map(tqdm_kwargs):
|
||||
"""Test contrib.tmap"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
b = [i + 1 for i in a]
|
||||
if sys.version_info[:1] < (3,):
|
||||
assert tmap(lambda x: x + 1, a, file=our_file) == map(incr, a)
|
||||
assert tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs) == map(
|
||||
incr, a
|
||||
)
|
||||
else:
|
||||
gen = tmap(lambda x: x + 1, a, file=our_file)
|
||||
gen = tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs)
|
||||
assert gen != b
|
||||
assert list(gen) == b
|
||||
|
|
|
@ -16,6 +16,7 @@ __all__ = ['tenumerate', 'tzip', 'tmap']
|
|||
|
||||
class DummyTqdmFile(ObjectWrapper):
|
||||
"""Dummy file-like that will write to tqdm"""
|
||||
|
||||
def __init__(self, wrapped):
|
||||
super(DummyTqdmFile, self).__init__(wrapped)
|
||||
self._buf = []
|
||||
|
@ -80,7 +81,7 @@ def tzip(iter1, *iter2plus, **tqdm_kwargs):
|
|||
"""
|
||||
kwargs = tqdm_kwargs.copy()
|
||||
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
|
||||
for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus):
|
||||
for i in zip(tqdm_class(iter1, **kwargs), *iter2plus):
|
||||
yield i
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue