tqdm/tests/tests_contrib.py

72 lines
2.2 KiB
Python
Raw Normal View History

2020-01-17 19:29:14 +00:00
"""
2020-01-19 16:47:11 +00:00
Tests for `tqdm.contrib`.
2020-01-17 19:29:14 +00:00
"""
2020-01-17 19:02:06 +00:00
import sys
import pytest
from tqdm import tqdm
2021-01-09 17:00:18 +00:00
from tqdm.contrib import tenumerate, tmap, tzip
from .tests_tqdm import StringIO, closing, importorskip
2020-10-24 18:36:45 +00:00
def incr(x):
    """Return ``x + 1``; used as the mapped function in the ``tmap`` test."""
    incremented = x + 1
    return incremented
2021-03-16 17:59:55 +00:00
@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_enumerate(tqdm_kwargs):
    """Test contrib.tenumerate

    ``tenumerate`` must match builtin ``enumerate`` with and without a
    ``start`` offset. When wrapping a bare generator (which has no
    ``len()``), the bar output must contain "100%" only if ``total=`` is
    passed explicitly.
    """
    seq = range(9)

    with closing(StringIO()) as fobj:
        assert list(tenumerate(seq, file=fobj, **tqdm_kwargs)) == list(enumerate(seq))
        shifted = list(tenumerate(seq, 42, file=fobj, **tqdm_kwargs))
        assert shifted == list(enumerate(seq, 42))

    # (extra tenumerate kwargs, whether "100%" is expected in the output)
    cases = (({}, False), ({"total": len(seq)}, True))
    for extra, expect_full in cases:
        with closing(StringIO()) as fobj:
            kwargs = dict(tqdm_kwargs, file=fobj, **extra)
            list(tenumerate((x for x in seq), **kwargs))
            assert ("100%" in fobj.getvalue()) == expect_full
def test_enumerate_numpy():
    """Test contrib.tenumerate(numpy.ndarray)

    For an ndarray, ``tenumerate`` is expected to yield the same
    (index, value) pairs as ``numpy.ndenumerate``. The test is skipped when
    numpy is not installed.
    """
    np = importorskip("numpy")
    with closing(StringIO()) as fobj:
        arr = np.random.random((42, 7))
        got = list(tenumerate(arr, file=fobj))
        assert got == list(np.ndenumerate(arr))
2021-03-16 17:59:55 +00:00
@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_zip(tqdm_kwargs):
    """Test contrib.tzip

    ``tzip(a, b)`` must yield exactly the pairs builtin ``zip(a, b)`` does.
    On Python 3 the result is additionally checked to be lazy: it must not
    compare equal to the already-built list of expected pairs.
    """
    with closing(StringIO()) as our_file:
        a = range(9)
        b = [i + 1 for i in a]
        expected = list(zip(a, b))
        gen = tzip(a, b, file=our_file, **tqdm_kwargs)
        if sys.version_info[:1] >= (3,):
            # A lazy iterator never compares equal to a list.
            assert gen != expected
        # Materialise before comparing so this holds whether ``tzip``
        # returns a list or a generator (on Python 2 ``zip`` itself
        # returns a list, so a direct ``==`` against a generator fails).
        assert list(gen) == expected
2021-03-16 17:59:55 +00:00
@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_map(tqdm_kwargs):
    """Test contrib.tmap

    ``tmap(incr, a)`` must produce the same values as ``map(incr, a)``,
    i.e. every element of ``a`` plus one. On Python 3 the result is
    additionally checked to be lazy: it must not compare equal to the
    already-built list of expected values.
    """
    with closing(StringIO()) as our_file:
        a = range(9)
        # Expected values are spelled out independently of ``incr``.
        b = [i + 1 for i in a]
        gen = tmap(incr, a, file=our_file, **tqdm_kwargs)
        if sys.version_info[:1] >= (3,):
            # A lazy iterator never compares equal to a list.
            assert gen != b
        # Materialise before comparing so this holds whether ``tmap``
        # returns a list or a lazy iterator (e.g. on Python 2).
        assert list(gen) == b