From f8f06a986c0a1aa06a0e04af4cf80eea75241bc0 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Sun, 19 Jan 2020 15:46:44 +0000 Subject: [PATCH] spilt `contrib.wraps` into `contrib` and `contrib.concurrent` --- tqdm/contrib/__init__.py | 70 +++++++++++++++++++ tqdm/contrib/concurrent.py | 58 ++++++++++++++++ tqdm/contrib/wraps.py | 119 --------------------------------- tqdm/tests/tests_concurrent.py | 36 ++++++++++ tqdm/tests/tests_contrib.py | 37 ++-------- 5 files changed, 170 insertions(+), 150 deletions(-) create mode 100644 tqdm/contrib/concurrent.py delete mode 100644 tqdm/contrib/wraps.py create mode 100644 tqdm/tests/tests_concurrent.py diff --git a/tqdm/contrib/__init__.py b/tqdm/contrib/__init__.py index 380a18d8..1dddacf4 100644 --- a/tqdm/contrib/__init__.py +++ b/tqdm/contrib/__init__.py @@ -1,5 +1,16 @@ +""" +Thin wrappers around common functions. + +Subpackages contain potentially unstable extensions. +""" from tqdm import tqdm +from tqdm.auto import tqdm as tqdm_auto from tqdm.utils import ObjectWrapper +from copy import deepcopy +import functools +import sys +__author__ = {"github.com/": ["casperdcl"]} +__all__ = ['tenumerate', 'tzip', 'tmap'] class DummyTqdmFile(ObjectWrapper): @@ -8,3 +19,62 @@ class DummyTqdmFile(ObjectWrapper): # Avoid print() second call (useless \n) if len(x.rstrip()) > 0: tqdm.write(x, file=self._wrapped, nolock=nolock) + + +def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto, + **tqdm_kwargs): + """ + Equivalent of `numpy.ndenumerate` or builtin `enumerate`. + + Parameters + ---------- + tqdm_class : [default: tqdm.auto.tqdm]. + """ + try: + import numpy as np + except ImportError: + pass + else: + if isinstance(iterable, np.ndarray): + return tqdm_class(np.ndenumerate(iterable), + total=total or len(iterable), **tqdm_kwargs) + return enumerate(tqdm_class(iterable, start, **tqdm_kwargs)) + + +def _tzip(iter1, *iter2plus, **tqdm_kwargs): + """ + Equivalent of builtin `zip`. + + Parameters + ---------- + tqdm_class : [default: tqdm.auto.tqdm]. + """ + kwargs = deepcopy(tqdm_kwargs) + tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) + for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus): + yield i + + +def _tmap(function, *sequences, **tqdm_kwargs): + """ + Equivalent of builtin `map`. + + Parameters + ---------- + tqdm_class : [default: tqdm.auto.tqdm]. + """ + for i in _tzip(*sequences, **tqdm_kwargs): + yield function(*i) + + +if sys.version_info[:1] < (3,): + @functools.wraps(_tzip) + def tzip(*args, **kwargs): + return list(_tzip(*args, **kwargs)) + + @functools.wraps(_tmap) + def tmap(*args, **kwargs): + return list(_tmap(*args, **kwargs)) +else: + tzip = _tzip + tmap = _tmap diff --git a/tqdm/contrib/concurrent.py b/tqdm/contrib/concurrent.py new file mode 100644 index 00000000..e5219681 --- /dev/null +++ b/tqdm/contrib/concurrent.py @@ -0,0 +1,58 @@ +""" +Thin wrappers around `concurrent.futures`. +""" +from __future__ import absolute_import +from tqdm.auto import tqdm as tqdm_auto +from copy import deepcopy +try: + from os import cpu_count +except ImportError: + try: + from multiprocessing import cpu_count + except ImportError: + def cpu_count(): + return 4 +__author__ = {"github.com/": ["casperdcl"]} +__all__ = ['thread_map', 'process_map'] + + +def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): + """ + Implementation of `thread_map` and `process_map`. + + Parameters + ---------- + tqdm_class : [default: tqdm.auto.tqdm]. + """ + kwargs = deepcopy(tqdm_kwargs) + kwargs.setdefault("total", len(iterables[0])) + tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) + max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4)) + with PoolExecutor(max_workers=max_workers) as ex: + return list(tqdm_class(ex.map(fn, *iterables), **kwargs)) + + +def thread_map(fn, *iterables, **tqdm_kwargs): + """ + Equivalent of `list(map(fn, *iterables))` + driven by `concurrent.futures.ThreadPoolExecutor`. + + Parameters + ---------- + tqdm_class : [default: tqdm.auto.tqdm]. + """ + from concurrent.futures import ThreadPoolExecutor + return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs) + + +def process_map(fn, *iterables, **tqdm_kwargs): + """ + Equivalent of `list(map(fn, *iterables))` + driven by `concurrent.futures.ProcessPoolExecutor`. + + Parameters + ---------- + tqdm_class : [default: tqdm.auto.tqdm]. + """ + from concurrent.futures import ProcessPoolExecutor + return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs) diff --git a/tqdm/contrib/wraps.py b/tqdm/contrib/wraps.py deleted file mode 100644 index 7a4ddb0e..00000000 --- a/tqdm/contrib/wraps.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -Thin wrappers around common functions. -""" -from tqdm.auto import tqdm as tqdm_auto -from copy import deepcopy -import functools -try: - from os import cpu_count -except ImportError: - try: - from multiprocessing import cpu_count - except ImportError: - def cpu_count(): - return 4 -import sys - -__author__ = {"github.com/": ["casperdcl"]} -__all__ = ['tenumerate', 'tzip', 'tmap', 'thread_map', 'process_map'] - - -def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto, - **tqdm_kwargs): - """ - Equivalent of `numpy.ndenumerate` or builtin `enumerate`. - - Parameters - ---------- - tqdm_class : [default: tqdm.auto.tqdm]. - """ - try: - import numpy as np - except ImportError: - pass - else: - if isinstance(iterable, np.ndarray): - return tqdm_class(np.ndenumerate(iterable), - total=total or len(iterable), **tqdm_kwargs) - return enumerate(tqdm_class(iterable, start, **tqdm_kwargs)) - - -def _tzip(iter1, *iter2plus, **tqdm_kwargs): - """ - Equivalent of builtin `zip`. - - Parameters - ---------- - tqdm_class : [default: tqdm.auto.tqdm]. - """ - kwargs = deepcopy(tqdm_kwargs) - tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) - for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus): - yield i - - -def _tmap(function, *sequences, **tqdm_kwargs): - """ - Equivalent of builtin `map`. - - Parameters - ---------- - tqdm_class : [default: tqdm.auto.tqdm]. - """ - for i in _tzip(*sequences, **tqdm_kwargs): - yield function(*i) - - -def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): - """ - Implementation of `thread_map` and `process_map`. - - Parameters - ---------- - tqdm_class : [default: tqdm.auto.tqdm]. - """ - kwargs = deepcopy(tqdm_kwargs) - kwargs.setdefault("total", len(iterables[0])) - tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) - max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4)) - with PoolExecutor(max_workers=max_workers) as ex: - return list(tqdm_class(ex.map(fn, *iterables), **kwargs)) - - -def thread_map(fn, *iterables, **tqdm_kwargs): - """ - Equivalent of `list(map(fn, *iterables))` - driven by `concurrent.futures.ThreadPoolExecutor`. - - Parameters - ---------- - tqdm_class : [default: tqdm.auto.tqdm]. - """ - from concurrent.futures import ThreadPoolExecutor - return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs) - - -def process_map(fn, *iterables, **tqdm_kwargs): - """ - Equivalent of `list(map(fn, *iterables))` - driven by `concurrent.futures.ProcessPoolExecutor`. - - Parameters - ---------- - tqdm_class : [default: tqdm.auto.tqdm]. - """ - from concurrent.futures import ProcessPoolExecutor - return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs) - - -if sys.version_info[:1] < (3,): - @functools.wraps(_tzip) - def tzip(*args, **kwargs): - return list(_tzip(*args, **kwargs)) - - @functools.wraps(_tmap) - def tmap(*args, **kwargs): - return list(_tmap(*args, **kwargs)) -else: - tzip = _tzip - tmap = _tmap diff --git a/tqdm/tests/tests_concurrent.py b/tqdm/tests/tests_concurrent.py new file mode 100644 index 00000000..57412cae --- /dev/null +++ b/tqdm/tests/tests_concurrent.py @@ -0,0 +1,36 @@ +""" +Tests for `tqdm.contrib` +""" +import sys +from tqdm.contrib.concurrent import thread_map, process_map +from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, \ + closing + + +def incr(x): + """Dummy function""" + return x + 1 + +@with_setup(pretest, posttest) +def test_thread_map(): + """Test contrib.concurrent.thread_map""" + with closing(StringIO()) as our_file: + a = range(9) + b = [i + 1 for i in a] + try: + assert thread_map(lambda x: x + 1, a, file=our_file) == b + except ImportError: + raise SkipTest + assert thread_map(incr, a, file=our_file) == b + + +@with_setup(pretest, posttest) +def test_process_map(): + """Test contrib.concurrent.process_map""" + with closing(StringIO()) as our_file: + a = range(9) + b = [i + 1 for i in a] + try: + assert process_map(incr, a, file=our_file) == b + except ImportError: + raise SkipTest diff --git a/tqdm/tests/tests_contrib.py b/tqdm/tests/tests_contrib.py index e3772839..218f9d87 100644 --- a/tqdm/tests/tests_contrib.py +++ b/tqdm/tests/tests_contrib.py @@ -1,20 +1,20 @@ """ Tests for `tqdm.contrib` """ -from __future__ import division import sys -from tqdm.contrib.wraps import tenumerate, tzip, tmap, thread_map, process_map +from tqdm.contrib import tenumerate, tzip, tmap from tests_tqdm import with_setup, pretest, posttest, SkipTest, StringIO, \ closing def incr(x): + """Dummy function""" return x + 1 @with_setup(pretest, posttest) def test_enumerate(): - """Test contrib.wraps.tenumerate""" + """Test contrib.tenumerate""" with closing(StringIO()) as our_file: a = range(9) assert list(tenumerate(a, file=our_file)) == list(enumerate(a)) @@ -22,7 +22,7 @@ def test_enumerate(): @with_setup(pretest, posttest) def test_enumerate_numpy(): - """Test contrib.wraps.tenumerate(numpy.ndarray)""" + """Test contrib.tenumerate(numpy.ndarray)""" try: import numpy as np except ImportError: @@ -34,7 +34,7 @@ def test_enumerate_numpy(): @with_setup(pretest, posttest) def test_zip(): - """Test contrib.wraps.tzip""" + """Test contrib.tzip""" with closing(StringIO()) as our_file: a = range(9) b = [i + 1 for i in a] @@ -48,7 +48,7 @@ def test_zip(): @with_setup(pretest, posttest) def test_map(): - """Test contrib.wraps.tmap""" + """Test contrib.tmap""" with closing(StringIO()) as our_file: a = range(9) b = [i + 1 for i in a] @@ -58,28 +58,3 @@ def test_map(): gen = tmap(lambda x: x + 1, a, file=our_file) assert gen != b assert list(gen) == b - - -@with_setup(pretest, posttest) -def test_thread_map(): - """Test contrib.wraps.thread_map""" - with closing(StringIO()) as our_file: - a = range(9) - b = [i + 1 for i in a] - try: - assert thread_map(lambda x: x + 1, a, file=our_file) == b - except ImportError: - raise SkipTest - assert thread_map(incr, a, file=our_file) == b - - -@with_setup(pretest, posttest) -def test_process_map(): - """Test contrib.wraps.process_map""" - with closing(StringIO()) as our_file: - a = range(9) - b = [i + 1 for i in a] - try: - assert process_map(incr, a, file=our_file) == b - except ImportError: - raise SkipTest