From 572a8e87256ea4ed06f9dccf187fb7a52a7fb363 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Wed, 7 Oct 2020 15:11:23 +0100 Subject: [PATCH] contrib.concurrent.process_map: fix threading.RLock pickling error - fixes #920 --- tqdm/contrib/concurrent.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tqdm/contrib/concurrent.py b/tqdm/contrib/concurrent.py index eda09078..8f6077b9 100644 --- a/tqdm/contrib/concurrent.py +++ b/tqdm/contrib/concurrent.py @@ -35,6 +35,7 @@ def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): tqdm_class : [default: tqdm.auto.tqdm]. max_workers : [default: min(32, cpu_count() + 4)]. chunksize : [default: 1]. + lock_name : [default: "":str]. """ kwargs = tqdm_kwargs.copy() if "total" not in kwargs: @@ -42,12 +43,14 @@ def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4)) chunksize = kwargs.pop("chunksize", 1) + lock_name = kwargs.pop("lock_name", "") pool_kwargs = dict(max_workers=max_workers) sys_version = sys.version_info[:2] if sys_version >= (3, 7): # share lock in case workers are already using `tqdm` - pool_kwargs.update( - initializer=tqdm_class.set_lock, initargs=(tqdm_class.get_lock(),)) + lock = tqdm_class.get_lock() + lock = getattr(lock, lock_name, lock) + pool_kwargs.update(initializer=tqdm_class.set_lock, initargs=(lock,)) map_args = {} if not (3, 0) < sys_version < (3, 5): map_args.update(chunksize=chunksize) @@ -90,6 +93,8 @@ def process_map(fn, *iterables, **tqdm_kwargs): chunksize : int, optional Size of chunks sent to worker processes; passed to `concurrent.futures.ProcessPoolExecutor.map`. [default: 1]. + lock_name : str, optional + Member of `tqdm_class.get_lock()` to use [default: mp_lock]. """ from concurrent.futures import ProcessPoolExecutor if iterables and "chunksize" not in tqdm_kwargs: @@ -102,4 +107,7 @@ def process_map(fn, *iterables, **tqdm_kwargs): " This may seriously degrade multiprocess performance." " Set `chunksize=1` or more." % longest_iterable_len, TqdmWarning, stacklevel=2) + if "lock_name" not in tqdm_kwargs: + tqdm_kwargs = tqdm_kwargs.copy() + tqdm_kwargs["lock_name"] = "mp_lock" return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)