From ecbb9c4b9f89120ba04642852780d592c024b6ef Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Wed, 12 Feb 2020 11:50:42 +0100
Subject: [PATCH] load Underscore state when multiprocessing

---
 spacy/language.py          | 11 ++++++++---
 spacy/tokens/underscore.py |  8 ++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/spacy/language.py b/spacy/language.py
index 5544b6341..71180a65d 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -15,6 +15,7 @@ import multiprocessing as mp
 from itertools import chain, cycle
 
 from .tokenizer import Tokenizer
+from .tokens.underscore import Underscore
 from .vocab import Vocab
 from .lemmatizer import Lemmatizer
 from .lookups import Lookups
@@ -852,7 +853,10 @@ class Language(object):
         sender.send()
 
         procs = [
-            mp.Process(target=_apply_pipes, args=(self.make_doc, pipes, rch, sch))
+            mp.Process(
+                target=_apply_pipes,
+                args=(self.make_doc, pipes, rch, sch, Underscore.get_state()),
+            )
             for rch, sch in zip(texts_q, bytedocs_send_ch)
         ]
         for proc in procs:
@@ -1107,7 +1111,7 @@ def _pipe(docs, proc, kwargs):
         yield doc
 
 
-def _apply_pipes(make_doc, pipes, reciever, sender):
+def _apply_pipes(make_doc, pipes, receiver, sender, underscore_state):
     """Worker for Language.pipe
 
     receiver (multiprocessing.Connection): Pipe to receive text. Usually
@@ -1115,8 +1119,9 @@ def _apply_pipes(make_doc, pipes, reciever, sender):
     sender (multiprocessing.Connection): Pipe to send doc. Usually created by
         `multiprocessing.Pipe()`
     """
+    Underscore.load_state(underscore_state)
     while True:
-        texts = reciever.get()
+        texts = receiver.get()
         docs = (make_doc(text) for text in texts)
         for pipe in pipes:
             docs = pipe(docs)
diff --git a/spacy/tokens/underscore.py b/spacy/tokens/underscore.py
index b36fe9294..8dac8526e 100644
--- a/spacy/tokens/underscore.py
+++ b/spacy/tokens/underscore.py
@@ -79,6 +79,14 @@ class Underscore(object):
     def _get_key(self, name):
         return ("._.", name, self._start, self._end)
 
+    @classmethod
+    def get_state(cls):
+        return cls.token_extensions, cls.span_extensions, cls.doc_extensions
+
+    @classmethod
+    def load_state(cls, state):
+        cls.token_extensions, cls.span_extensions, cls.doc_extensions = state
+
 
 def get_ext_args(**kwargs):
     """Validate and convert arguments. Reused in Doc, Token and Span."""