diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index a6b02ba2c..ec87dce12 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -183,6 +183,14 @@ def merge_phrase(matcher, doc, i, matches): span.merge(ent_type=label, ent_id=ent_id) +def unpickle_matcher(vocab, patterns, callbacks): + matcher = Matcher(vocab) + for key, specs in patterns.items(): + callback = callbacks.get(key, None) + matcher.add(key, callback, *specs) + return matcher + + cdef class Matcher: """Match sequences of tokens, based on pattern rules.""" cdef Pool mem @@ -206,7 +214,8 @@ cdef class Matcher: self.mem = Pool() def __reduce__(self): - return (self.__class__, (self.vocab, self._patterns), None, None) + data = (self.vocab, self._patterns, self._callbacks) + return (unpickle_matcher, data, None, None) def __len__(self): """Get the number of rules added to the matcher. Note that this only @@ -259,12 +268,12 @@ cdef class Matcher: "key: {key}\n") raise ValueError(msg.format(key=key)) key = self._normalize_key(key) - self._patterns.setdefault(key, []) - self._callbacks[key] = on_match for pattern in patterns: specs = _convert_strings(pattern, self.vocab.strings) self.patterns.push_back(init_pattern(self.mem, key, specs)) - self._patterns[key].append(specs) + self._patterns.setdefault(key, []) + self._callbacks[key] = on_match + self._patterns[key].extend(patterns) def remove(self, key): """Remove a rule from the matcher. A KeyError is raised if the key does