mirror of https://github.com/tqdm/tqdm.git
spilt `contrib.wraps` into `contrib` and `contrib.concurrent`
This commit is contained in:
parent
cf74393cd3
commit
f8f06a986c
|
@ -1,5 +1,16 @@
|
|||
"""
|
||||
Thin wrappers around common functions.
|
||||
|
||||
Subpackages contain potentially unstable extensions.
|
||||
"""
|
||||
from tqdm import tqdm
|
||||
from tqdm.auto import tqdm as tqdm_auto
|
||||
from tqdm.utils import ObjectWrapper
|
||||
from copy import deepcopy
|
||||
import functools
|
||||
import sys
|
||||
__author__ = {"github.com/": ["casperdcl"]}
|
||||
__all__ = ['tenumerate', 'tzip', 'tmap']
|
||||
|
||||
|
||||
class DummyTqdmFile(ObjectWrapper):
|
||||
|
@ -8,3 +19,62 @@ class DummyTqdmFile(ObjectWrapper):
|
|||
# Avoid print() second call (useless \n)
|
||||
if len(x.rstrip()) > 0:
|
||||
tqdm.write(x, file=self._wrapped, nolock=nolock)
|
||||
|
||||
|
||||
def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto,
|
||||
**tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of `numpy.ndenumerate` or builtin `enumerate`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
if isinstance(iterable, np.ndarray):
|
||||
return tqdm_class(np.ndenumerate(iterable),
|
||||
total=total or len(iterable), **tqdm_kwargs)
|
||||
return enumerate(tqdm_class(iterable, start, **tqdm_kwargs))
|
||||
|
||||
|
||||
def _tzip(iter1, *iter2plus, **tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of builtin `zip`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
kwargs = deepcopy(tqdm_kwargs)
|
||||
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
|
||||
for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus):
|
||||
yield i
|
||||
|
||||
|
||||
def _tmap(function, *sequences, **tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of builtin `map`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
for i in _tzip(*sequences, **tqdm_kwargs):
|
||||
yield function(*i)
|
||||
|
||||
|
||||
if sys.version_info[:1] < (3,):
|
||||
@functools.wraps(_tzip)
|
||||
def tzip(*args, **kwargs):
|
||||
return list(_tzip(*args, **kwargs))
|
||||
|
||||
@functools.wraps(_tmap)
|
||||
def tmap(*args, **kwargs):
|
||||
return list(_tmap(*args, **kwargs))
|
||||
else:
|
||||
tzip = _tzip
|
||||
tmap = _tmap
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
"""
|
||||
Thin wrappers around `concurrent.futures`.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from tqdm.auto import tqdm as tqdm_auto
|
||||
from copy import deepcopy
|
||||
try:
|
||||
from os import cpu_count
|
||||
except ImportError:
|
||||
try:
|
||||
from multiprocessing import cpu_count
|
||||
except ImportError:
|
||||
def cpu_count():
|
||||
return 4
|
||||
__author__ = {"github.com/": ["casperdcl"]}
|
||||
__all__ = ['thread_map', 'process_map']
|
||||
|
||||
|
||||
def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
|
||||
"""
|
||||
Implementation of `thread_map` and `process_map`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
kwargs = deepcopy(tqdm_kwargs)
|
||||
kwargs.setdefault("total", len(iterables[0]))
|
||||
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
|
||||
max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
|
||||
with PoolExecutor(max_workers=max_workers) as ex:
|
||||
return list(tqdm_class(ex.map(fn, *iterables), **kwargs))
|
||||
|
||||
|
||||
def thread_map(fn, *iterables, **tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of `list(map(fn, *iterables))`
|
||||
driven by `concurrent.futures.ThreadPoolExecutor`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
|
||||
|
||||
|
||||
def process_map(fn, *iterables, **tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of `list(map(fn, *iterables))`
|
||||
driven by `concurrent.futures.ProcessPoolExecutor`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
|
|
@ -1,119 +0,0 @@
|
|||
"""
|
||||
Thin wrappers around common functions.
|
||||
"""
|
||||
from tqdm.auto import tqdm as tqdm_auto
|
||||
from copy import deepcopy
|
||||
import functools
|
||||
try:
|
||||
from os import cpu_count
|
||||
except ImportError:
|
||||
try:
|
||||
from multiprocessing import cpu_count
|
||||
except ImportError:
|
||||
def cpu_count():
|
||||
return 4
|
||||
import sys
|
||||
|
||||
__author__ = {"github.com/": ["casperdcl"]}
|
||||
__all__ = ['tenumerate', 'tzip', 'tmap', 'thread_map', 'process_map']
|
||||
|
||||
|
||||
def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto,
|
||||
**tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of `numpy.ndenumerate` or builtin `enumerate`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
if isinstance(iterable, np.ndarray):
|
||||
return tqdm_class(np.ndenumerate(iterable),
|
||||
total=total or len(iterable), **tqdm_kwargs)
|
||||
return enumerate(tqdm_class(iterable, start, **tqdm_kwargs))
|
||||
|
||||
|
||||
def _tzip(iter1, *iter2plus, **tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of builtin `zip`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
kwargs = deepcopy(tqdm_kwargs)
|
||||
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
|
||||
for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus):
|
||||
yield i
|
||||
|
||||
|
||||
def _tmap(function, *sequences, **tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of builtin `map`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
for i in _tzip(*sequences, **tqdm_kwargs):
|
||||
yield function(*i)
|
||||
|
||||
|
||||
def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
|
||||
"""
|
||||
Implementation of `thread_map` and `process_map`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
kwargs = deepcopy(tqdm_kwargs)
|
||||
kwargs.setdefault("total", len(iterables[0]))
|
||||
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
|
||||
max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
|
||||
with PoolExecutor(max_workers=max_workers) as ex:
|
||||
return list(tqdm_class(ex.map(fn, *iterables), **kwargs))
|
||||
|
||||
|
||||
def thread_map(fn, *iterables, **tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of `list(map(fn, *iterables))`
|
||||
driven by `concurrent.futures.ThreadPoolExecutor`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
|
||||
|
||||
|
||||
def process_map(fn, *iterables, **tqdm_kwargs):
|
||||
"""
|
||||
Equivalent of `list(map(fn, *iterables))`
|
||||
driven by `concurrent.futures.ProcessPoolExecutor`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tqdm_class : [default: tqdm.auto.tqdm].
|
||||
"""
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
|
||||
|
||||
|
||||
if sys.version_info[:1] < (3,):
|
||||
@functools.wraps(_tzip)
|
||||
def tzip(*args, **kwargs):
|
||||
return list(_tzip(*args, **kwargs))
|
||||
|
||||
@functools.wraps(_tmap)
|
||||
def tmap(*args, **kwargs):
|
||||
return list(_tmap(*args, **kwargs))
|
||||
else:
|
||||
tzip = _tzip
|
||||
tmap = _tmap
|
|
@ -0,0 +1,36 @@
|
|||
"""
|
||||
Tests for `tqdm.contrib`
|
||||
"""
|
||||
import sys
|
||||
from tqdm.contrib.concurrent import thread_map, process_map
|
||||
from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, \
|
||||
closing
|
||||
|
||||
|
||||
def incr(x):
|
||||
"""Dummy function"""
|
||||
return x + 1
|
||||
|
||||
@with_setup(pretest, posttest)
|
||||
def test_thread_map():
|
||||
"""Test contrib.concurrent.thread_map"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
b = [i + 1 for i in a]
|
||||
try:
|
||||
assert thread_map(lambda x: x + 1, a, file=our_file) == b
|
||||
except ImportError:
|
||||
raise SkipTest
|
||||
assert thread_map(incr, a, file=our_file) == b
|
||||
|
||||
|
||||
@with_setup(pretest, posttest)
|
||||
def test_process_map():
|
||||
"""Test contrib.concurrent.process_map"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
b = [i + 1 for i in a]
|
||||
try:
|
||||
assert process_map(incr, a, file=our_file) == b
|
||||
except ImportError:
|
||||
raise SkipTest
|
|
@ -1,20 +1,20 @@
|
|||
"""
|
||||
Tests for `tqdm.contrib`
|
||||
"""
|
||||
from __future__ import division
|
||||
import sys
|
||||
from tqdm.contrib.wraps import tenumerate, tzip, tmap, thread_map, process_map
|
||||
from tqdm.contrib import tenumerate, tzip, tmap
|
||||
from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, \
|
||||
closing
|
||||
|
||||
|
||||
def incr(x):
|
||||
"""Dummy function"""
|
||||
return x + 1
|
||||
|
||||
|
||||
@with_setup(pretest, posttest)
|
||||
def test_enumerate():
|
||||
"""Test contrib.wraps.tenumerate"""
|
||||
"""Test contrib.tenumerate"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
assert list(tenumerate(a, file=our_file)) == list(enumerate(a))
|
||||
|
@ -22,7 +22,7 @@ def test_enumerate():
|
|||
|
||||
@with_setup(pretest, posttest)
|
||||
def test_enumerate_numpy():
|
||||
"""Test contrib.wraps.tenumerate(numpy.ndarray)"""
|
||||
"""Test contrib.tenumerate(numpy.ndarray)"""
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
|
@ -34,7 +34,7 @@ def test_enumerate_numpy():
|
|||
|
||||
@with_setup(pretest, posttest)
|
||||
def test_zip():
|
||||
"""Test contrib.wraps.tzip"""
|
||||
"""Test contrib.tzip"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
b = [i + 1 for i in a]
|
||||
|
@ -48,7 +48,7 @@ def test_zip():
|
|||
|
||||
@with_setup(pretest, posttest)
|
||||
def test_map():
|
||||
"""Test contrib.wraps.tmap"""
|
||||
"""Test contrib.tmap"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
b = [i + 1 for i in a]
|
||||
|
@ -58,28 +58,3 @@ def test_map():
|
|||
gen = tmap(lambda x: x + 1, a, file=our_file)
|
||||
assert gen != b
|
||||
assert list(gen) == b
|
||||
|
||||
|
||||
@with_setup(pretest, posttest)
|
||||
def test_thread_map():
|
||||
"""Test contrib.wraps.thread_map"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
b = [i + 1 for i in a]
|
||||
try:
|
||||
assert thread_map(lambda x: x + 1, a, file=our_file) == b
|
||||
except ImportError:
|
||||
raise SkipTest
|
||||
assert thread_map(incr, a, file=our_file) == b
|
||||
|
||||
|
||||
@with_setup(pretest, posttest)
|
||||
def test_process_map():
|
||||
"""Test contrib.wraps.process_map"""
|
||||
with closing(StringIO()) as our_file:
|
||||
a = range(9)
|
||||
b = [i + 1 for i in a]
|
||||
try:
|
||||
assert process_map(incr, a, file=our_file) == b
|
||||
except ImportError:
|
||||
raise SkipTest
|
||||
|
|
Loading…
Reference in New Issue