mirror of https://github.com/tqdm/tqdm.git
contrib.concurrent.process_map: fix threading.RLock pickling error
- fixes #920
This commit is contained in:
parent
a00f7f3b74
commit
572a8e8725
|
@ -35,6 +35,7 @@ def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
|
||||||
tqdm_class : [default: tqdm.auto.tqdm].
|
tqdm_class : [default: tqdm.auto.tqdm].
|
||||||
max_workers : [default: min(32, cpu_count() + 4)].
|
max_workers : [default: min(32, cpu_count() + 4)].
|
||||||
chunksize : [default: 1].
|
chunksize : [default: 1].
|
||||||
|
lock_name : [default: "":str].
|
||||||
"""
|
"""
|
||||||
kwargs = tqdm_kwargs.copy()
|
kwargs = tqdm_kwargs.copy()
|
||||||
if "total" not in kwargs:
|
if "total" not in kwargs:
|
||||||
|
@ -42,12 +43,14 @@ def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
|
||||||
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
|
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
|
||||||
max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
|
max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
|
||||||
chunksize = kwargs.pop("chunksize", 1)
|
chunksize = kwargs.pop("chunksize", 1)
|
||||||
|
lock_name = kwargs.pop("lock_name", "")
|
||||||
pool_kwargs = dict(max_workers=max_workers)
|
pool_kwargs = dict(max_workers=max_workers)
|
||||||
sys_version = sys.version_info[:2]
|
sys_version = sys.version_info[:2]
|
||||||
if sys_version >= (3, 7):
|
if sys_version >= (3, 7):
|
||||||
# share lock in case workers are already using `tqdm`
|
# share lock in case workers are already using `tqdm`
|
||||||
pool_kwargs.update(
|
lock = tqdm_class.get_lock()
|
||||||
initializer=tqdm_class.set_lock, initargs=(tqdm_class.get_lock(),))
|
lock = getattr(lock, lock_name, lock)
|
||||||
|
pool_kwargs.update(initializer=tqdm_class.set_lock, initargs=(lock,))
|
||||||
map_args = {}
|
map_args = {}
|
||||||
if not (3, 0) < sys_version < (3, 5):
|
if not (3, 0) < sys_version < (3, 5):
|
||||||
map_args.update(chunksize=chunksize)
|
map_args.update(chunksize=chunksize)
|
||||||
|
@ -90,6 +93,8 @@ def process_map(fn, *iterables, **tqdm_kwargs):
|
||||||
chunksize : int, optional
|
chunksize : int, optional
|
||||||
Size of chunks sent to worker processes; passed to
|
Size of chunks sent to worker processes; passed to
|
||||||
`concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
|
`concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
|
||||||
|
lock_name : str, optional
|
||||||
|
Member of `tqdm_class.get_lock()` to use [default: mp_lock].
|
||||||
"""
|
"""
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
if iterables and "chunksize" not in tqdm_kwargs:
|
if iterables and "chunksize" not in tqdm_kwargs:
|
||||||
|
@ -102,4 +107,7 @@ def process_map(fn, *iterables, **tqdm_kwargs):
|
||||||
" This may seriously degrade multiprocess performance."
|
" This may seriously degrade multiprocess performance."
|
||||||
" Set `chunksize=1` or more." % longest_iterable_len,
|
" Set `chunksize=1` or more." % longest_iterable_len,
|
||||||
TqdmWarning, stacklevel=2)
|
TqdmWarning, stacklevel=2)
|
||||||
|
if "lock_name" not in tqdm_kwargs:
|
||||||
|
tqdm_kwargs = tqdm_kwargs.copy()
|
||||||
|
tqdm_kwargs["lock_name"] = "mp_lock"
|
||||||
return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
|
return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
|
||||||
|
|
Loading…
Reference in New Issue