mirror of https://github.com/explosion/spaCy.git
Fix unpickling of Matcher. Also store correct data in matcher._patterns
This commit is contained in:
parent
42a18ef903
commit
2ad050e668
|
@ -183,6 +183,14 @@ def merge_phrase(matcher, doc, i, matches):
|
|||
span.merge(ent_type=label, ent_id=ent_id)
|
||||
|
||||
|
||||
def unpickle_matcher(vocab, patterns, callbacks):
|
||||
matcher = Matcher(vocab)
|
||||
for key, specs in patterns.items():
|
||||
callback = callbacks.get(key, None)
|
||||
matcher.add(key, callback, *specs)
|
||||
return matcher
|
||||
|
||||
|
||||
cdef class Matcher:
|
||||
"""Match sequences of tokens, based on pattern rules."""
|
||||
cdef Pool mem
|
||||
|
@ -206,7 +214,8 @@ cdef class Matcher:
|
|||
self.mem = Pool()
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self.vocab, self._patterns), None, None)
|
||||
data = (self.vocab, self._patterns, self._callbacks)
|
||||
return (unpickle_matcher, data, None, None)
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of rules added to the matcher. Note that this only
|
||||
|
@ -259,12 +268,12 @@ cdef class Matcher:
|
|||
"key: {key}\n")
|
||||
raise ValueError(msg.format(key=key))
|
||||
key = self._normalize_key(key)
|
||||
self._patterns.setdefault(key, [])
|
||||
self._callbacks[key] = on_match
|
||||
for pattern in patterns:
|
||||
specs = _convert_strings(pattern, self.vocab.strings)
|
||||
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
||||
self._patterns[key].append(specs)
|
||||
self._patterns.setdefault(key, [])
|
||||
self._callbacks[key] = on_match
|
||||
self._patterns[key].extend(patterns)
|
||||
|
||||
def remove(self, key):
|
||||
"""Remove a rule from the matcher. A KeyError is raised if the key does
|
||||
|
|
Loading…
Reference in New Issue