master: document ThreadWatcher

This commit is contained in:
David Wilson 2018-10-04 19:15:59 +00:00
parent a7b1831ddf
commit 74cf9c3c96
1 changed files with 34 additions and 22 deletions

View File

@ -179,24 +179,35 @@ def scan_code_imports(co):
class ThreadWatcher(object): class ThreadWatcher(object):
""" """
Manage threads that waits for nother threads to shutdown, before invoking Manage threads that wait for another thread to shut down, before invoking
`on_join()`. In CPython it seems possible to use this method to ensure a `on_join()` for each associated ThreadWatcher.
non-main thread is signalled when the main thread has exitted, using yet
another thread as a proxy. In CPython it seems possible to use this method to ensure a non-main thread
is signalled when the main thread has exited, using a third thread as a
proxy.
""" """
_lock = threading.Lock() #: Protects remaining _cls_* members.
_pid = None _cls_lock = threading.Lock()
_instances_by_target = {}
_thread_by_target = {} #: PID of the process that last modified the class data. If the PID
#: changes, it means the thread watch dict refers to threads that no longer
#: exist in the current process (since it forked), and so must be reset.
_cls_pid = None
#: Map watched Thread -> list of ThreadWatcher instances.
_cls_instances_by_target = {}
#: Map watched Thread -> watcher Thread for each watched thread.
_cls_thread_by_target = {}
@classmethod @classmethod
def _reset(cls): def _reset(cls):
"""If we have forked since the watch dictionaries were initialized, all """If we have forked since the watch dictionaries were initialized, all
that has is garbage, so clear it.""" that has is garbage, so clear it."""
if os.getpid() != cls._pid: if os.getpid() != cls._cls_pid:
cls._pid = os.getpid() cls._cls_pid = os.getpid()
cls._instances_by_target.clear() cls._cls_instances_by_target.clear()
cls._thread_by_target.clear() cls._cls_thread_by_target.clear()
def __init__(self, target, on_join): def __init__(self, target, on_join):
self.target = target self.target = target
@ -205,33 +216,34 @@ class ThreadWatcher(object):
@classmethod @classmethod
def _watch(cls, target): def _watch(cls, target):
target.join() target.join()
for watcher in cls._instances_by_target[target]: for watcher in cls._cls_instances_by_target[target]:
watcher.on_join() watcher.on_join()
def install(self): def install(self):
self._lock.acquire() self._cls_lock.acquire()
try: try:
self._reset() self._reset()
self._instances_by_target.setdefault(self.target, []).append(self) lst = self._cls_instances_by_target.setdefault(self.target, [])
if self.target not in self._thread_by_target: lst.append(self)
self._thread_by_target[self.target] = threading.Thread( if self.target not in self._cls_thread_by_target:
self._cls_thread_by_target[self.target] = threading.Thread(
name='mitogen.master.join_thread_async', name='mitogen.master.join_thread_async',
target=self._watch, target=self._watch,
args=(self.target,) args=(self.target,)
) )
self._thread_by_target[self.target].start() self._cls_thread_by_target[self.target].start()
finally: finally:
self._lock.release() self._cls_lock.release()
def remove(self): def remove(self):
self._lock.acquire() self._cls_lock.acquire()
try: try:
self._reset() self._reset()
lst = self._instances_by_target.get(self.target, []) lst = self._cls_instances_by_target.get(self.target, [])
if self in lst: if self in lst:
lst.remove(self) lst.remove(self)
finally: finally:
self._lock.release() self._cls_lock.release()
@classmethod @classmethod
def watch(cls, target, on_join): def watch(cls, target, on_join):