spilt `contrib.wraps` into `contrib` and `contrib.concurrent`

This commit is contained in:
Casper da Costa-Luis 2020-01-19 15:46:44 +00:00
parent cf74393cd3
commit f8f06a986c
No known key found for this signature in database
GPG Key ID: 986B408043AE090D
5 changed files with 170 additions and 150 deletions

View File

@ -1,5 +1,16 @@
"""
Thin wrappers around common functions.
Subpackages contain potentially unstable extensions.
"""
from tqdm import tqdm
from tqdm.auto import tqdm as tqdm_auto
from tqdm.utils import ObjectWrapper
from copy import deepcopy
import functools
import sys
__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['tenumerate', 'tzip', 'tmap']
class DummyTqdmFile(ObjectWrapper):
@ -8,3 +19,62 @@ class DummyTqdmFile(ObjectWrapper):
# Avoid print() second call (useless \n)
if len(x.rstrip()) > 0:
tqdm.write(x, file=self._wrapped, nolock=nolock)
def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto,
**tqdm_kwargs):
"""
Equivalent of `numpy.ndenumerate` or builtin `enumerate`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
try:
import numpy as np
except ImportError:
pass
else:
if isinstance(iterable, np.ndarray):
return tqdm_class(np.ndenumerate(iterable),
total=total or len(iterable), **tqdm_kwargs)
return enumerate(tqdm_class(iterable, start, **tqdm_kwargs))
def _tzip(iter1, *iter2plus, **tqdm_kwargs):
"""
Equivalent of builtin `zip`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
kwargs = deepcopy(tqdm_kwargs)
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus):
yield i
def _tmap(function, *sequences, **tqdm_kwargs):
"""
Equivalent of builtin `map`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
for i in _tzip(*sequences, **tqdm_kwargs):
yield function(*i)
if sys.version_info[:1] < (3,):
@functools.wraps(_tzip)
def tzip(*args, **kwargs):
return list(_tzip(*args, **kwargs))
@functools.wraps(_tmap)
def tmap(*args, **kwargs):
return list(_tmap(*args, **kwargs))
else:
tzip = _tzip
tmap = _tmap

View File

@ -0,0 +1,58 @@
"""
Thin wrappers around `concurrent.futures`.
"""
from __future__ import absolute_import
from tqdm.auto import tqdm as tqdm_auto
from copy import deepcopy
try:
from os import cpu_count
except ImportError:
try:
from multiprocessing import cpu_count
except ImportError:
def cpu_count():
return 4
__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['thread_map', 'process_map']
def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
"""
Implementation of `thread_map` and `process_map`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
kwargs = deepcopy(tqdm_kwargs)
kwargs.setdefault("total", len(iterables[0]))
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
with PoolExecutor(max_workers=max_workers) as ex:
return list(tqdm_class(ex.map(fn, *iterables), **kwargs))
def thread_map(fn, *iterables, **tqdm_kwargs):
"""
Equivalent of `list(map(fn, *iterables))`
driven by `concurrent.futures.ThreadPoolExecutor`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
from concurrent.futures import ThreadPoolExecutor
return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
def process_map(fn, *iterables, **tqdm_kwargs):
"""
Equivalent of `list(map(fn, *iterables))`
driven by `concurrent.futures.ProcessPoolExecutor`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
from concurrent.futures import ProcessPoolExecutor
return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)

View File

@ -1,119 +0,0 @@
"""
Thin wrappers around common functions.
"""
from tqdm.auto import tqdm as tqdm_auto
from copy import deepcopy
import functools
try:
from os import cpu_count
except ImportError:
try:
from multiprocessing import cpu_count
except ImportError:
def cpu_count():
return 4
import sys
__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['tenumerate', 'tzip', 'tmap', 'thread_map', 'process_map']
def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto,
**tqdm_kwargs):
"""
Equivalent of `numpy.ndenumerate` or builtin `enumerate`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
try:
import numpy as np
except ImportError:
pass
else:
if isinstance(iterable, np.ndarray):
return tqdm_class(np.ndenumerate(iterable),
total=total or len(iterable), **tqdm_kwargs)
return enumerate(tqdm_class(iterable, start, **tqdm_kwargs))
def _tzip(iter1, *iter2plus, **tqdm_kwargs):
"""
Equivalent of builtin `zip`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
kwargs = deepcopy(tqdm_kwargs)
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus):
yield i
def _tmap(function, *sequences, **tqdm_kwargs):
"""
Equivalent of builtin `map`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
for i in _tzip(*sequences, **tqdm_kwargs):
yield function(*i)
def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
"""
Implementation of `thread_map` and `process_map`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
kwargs = deepcopy(tqdm_kwargs)
kwargs.setdefault("total", len(iterables[0]))
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
with PoolExecutor(max_workers=max_workers) as ex:
return list(tqdm_class(ex.map(fn, *iterables), **kwargs))
def thread_map(fn, *iterables, **tqdm_kwargs):
"""
Equivalent of `list(map(fn, *iterables))`
driven by `concurrent.futures.ThreadPoolExecutor`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
from concurrent.futures import ThreadPoolExecutor
return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
def process_map(fn, *iterables, **tqdm_kwargs):
"""
Equivalent of `list(map(fn, *iterables))`
driven by `concurrent.futures.ProcessPoolExecutor`.
Parameters
----------
tqdm_class : [default: tqdm.auto.tqdm].
"""
from concurrent.futures import ProcessPoolExecutor
return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
if sys.version_info[:1] < (3,):
@functools.wraps(_tzip)
def tzip(*args, **kwargs):
return list(_tzip(*args, **kwargs))
@functools.wraps(_tmap)
def tmap(*args, **kwargs):
return list(_tmap(*args, **kwargs))
else:
tzip = _tzip
tmap = _tmap

View File

@ -0,0 +1,36 @@
"""
Tests for `tqdm.contrib`
"""
import sys
from tqdm.contrib.concurrent import thread_map, process_map
from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, \
closing
def incr(x):
"""Dummy function"""
return x + 1
@with_setup(pretest, posttest)
def test_thread_map():
"""Test contrib.concurrent.thread_map"""
with closing(StringIO()) as our_file:
a = range(9)
b = [i + 1 for i in a]
try:
assert thread_map(lambda x: x + 1, a, file=our_file) == b
except ImportError:
raise SkipTest
assert thread_map(incr, a, file=our_file) == b
@with_setup(pretest, posttest)
def test_process_map():
"""Test contrib.concurrent.process_map"""
with closing(StringIO()) as our_file:
a = range(9)
b = [i + 1 for i in a]
try:
assert process_map(incr, a, file=our_file) == b
except ImportError:
raise SkipTest

View File

@ -1,20 +1,20 @@
"""
Tests for `tqdm.contrib`
"""
from __future__ import division
import sys
from tqdm.contrib.wraps import tenumerate, tzip, tmap, thread_map, process_map
from tqdm.contrib import tenumerate, tzip, tmap
from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, \
closing
def incr(x):
"""Dummy function"""
return x + 1
@with_setup(pretest, posttest)
def test_enumerate():
"""Test contrib.wraps.tenumerate"""
"""Test contrib.tenumerate"""
with closing(StringIO()) as our_file:
a = range(9)
assert list(tenumerate(a, file=our_file)) == list(enumerate(a))
@ -22,7 +22,7 @@ def test_enumerate():
@with_setup(pretest, posttest)
def test_enumerate_numpy():
"""Test contrib.wraps.tenumerate(numpy.ndarray)"""
"""Test contrib.tenumerate(numpy.ndarray)"""
try:
import numpy as np
except ImportError:
@ -34,7 +34,7 @@ def test_enumerate_numpy():
@with_setup(pretest, posttest)
def test_zip():
"""Test contrib.wraps.tzip"""
"""Test contrib.tzip"""
with closing(StringIO()) as our_file:
a = range(9)
b = [i + 1 for i in a]
@ -48,7 +48,7 @@ def test_zip():
@with_setup(pretest, posttest)
def test_map():
"""Test contrib.wraps.tmap"""
"""Test contrib.tmap"""
with closing(StringIO()) as our_file:
a = range(9)
b = [i + 1 for i in a]
@ -58,28 +58,3 @@ def test_map():
gen = tmap(lambda x: x + 1, a, file=our_file)
assert gen != b
assert list(gen) == b
@with_setup(pretest, posttest)
def test_thread_map():
"""Test contrib.wraps.thread_map"""
with closing(StringIO()) as our_file:
a = range(9)
b = [i + 1 for i in a]
try:
assert thread_map(lambda x: x + 1, a, file=our_file) == b
except ImportError:
raise SkipTest
assert thread_map(incr, a, file=our_file) == b
@with_setup(pretest, posttest)
def test_process_map():
"""Test contrib.wraps.process_map"""
with closing(StringIO()) as our_file:
a = range(9)
b = [i + 1 for i in a]
try:
assert process_map(incr, a, file=our_file) == b
except ImportError:
raise SkipTest