Restore support for pickling

This commit is contained in:
Adriane Boyd 2019-09-19 20:20:53 +02:00
parent 3931368ce8
commit 0d851db6d9
1 changed files with 17 additions and 0 deletions

View File

@ -30,6 +30,7 @@ cdef class PhraseMatcher:
cdef attr_id_t attr
cdef object _callbacks
cdef object _keywords
cdef object _docs
cdef bint _validate
def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
@ -49,6 +50,7 @@ cdef class PhraseMatcher:
self.keyword_trie_dict = dict()
self._callbacks = {}
self._keywords = {}
self._docs = {}
self._validate = validate
if isinstance(attr, long):
@ -80,6 +82,10 @@ cdef class PhraseMatcher:
"""
return key in self._callbacks
def __reduce__(self):
    """Describe how this matcher is pickled.

    RETURNS (tuple): `(unpickle_matcher, args, None, None)`, where `args`
        holds the vocab, the per-key pattern docs (`self._docs`) and the
        per-key callbacks (`self._callbacks`). No extra state or item
        iterators are supplied; the matcher is rebuilt purely by calling
        `unpickle_matcher(*args)`.
    """
    rebuild_args = (self.vocab, self._docs, self._callbacks)
    return unpickle_matcher, rebuild_args, None, None
def remove(self, key):
"""Remove a match-rule from the matcher by match ID.
@ -120,6 +126,7 @@ cdef class PhraseMatcher:
del self._keywords[key]
del self._callbacks[key]
del self._docs[key]
def add(self, key, on_match, *docs):
"""Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
@ -135,6 +142,8 @@ cdef class PhraseMatcher:
_ = self.vocab[key]
self._callbacks[key] = on_match
self._keywords.setdefault(key, [])
self._docs.setdefault(key, set())
self._docs[key].update(docs)
for doc in docs:
if len(doc) == 0:
@ -285,3 +294,11 @@ cdef class PhraseMatcher:
def _convert_to_array(self, Doc doc):
    """Build a uint64 array with one entry per token position in `doc`.

    doc (Doc): The document to convert.
    RETURNS (numpy.ndarray): 1-d array of dtype uint64 holding the result of
        `self.get_lex_value(doc, i)` for each index `i` in `range(len(doc))`.
    """
    lex_values = [self.get_lex_value(doc, idx) for idx in range(len(doc))]
    return np.array(lex_values, dtype=np.uint64)
def unpickle_matcher(vocab, docs, callbacks, attr="ORTH"):
    """Rebuild a PhraseMatcher from the data emitted by `__reduce__`.

    vocab (Vocab): Passed to the PhraseMatcher constructor.
    docs (dict): Maps each match key to the collection of pattern docs that
        were added under that key.
    callbacks (dict): Maps match keys to their on-match callback. Keys that
        are missing here are re-added with `None` as the callback.
    attr: Token attribute forwarded unchanged to the PhraseMatcher
        constructor. Defaults to "ORTH", the constructor's own default, so
        existing three-argument pickles keep loading exactly as before.
    RETURNS (PhraseMatcher): A new matcher with every key re-added.

    NOTE: `PhraseMatcher.__reduce__` passes only (vocab, docs, callbacks), so
    a pickled matcher is still restored with the default `attr` until it also
    supplies the attribute. The other constructor options (`max_length`,
    `validate`) are not restored either.
    """
    matcher = PhraseMatcher(vocab, attr=attr)
    for key, specs in docs.items():
        # `dict.get` already yields None for keys without a callback.
        matcher.add(key, callbacks.get(key), *specs)
    return matcher