diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index 7c694431e..72a4a97d6 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -142,7 +142,7 @@ def _convert_strings(token_specs, string_store): tokens = [] op = ONE for spec in token_specs: - token = [] + token = [] ops = (ONE,) for attr, value in spec.items(): if isinstance(attr, basestring) and attr.upper() == 'OP': @@ -173,7 +173,7 @@ cdef class Matcher: cdef public object _entities cdef public object _callbacks cdef public object _acceptors - + @classmethod def load(cls, path, vocab): '''Load the matcher and patterns from a file path. @@ -218,7 +218,7 @@ cdef class Matcher: def __reduce__(self): return (self.__class__, (self.vocab, self._patterns), None, None) - + property n_patterns: def __get__(self): return self.patterns.size() @@ -492,14 +492,14 @@ cdef class PhraseMatcher: abstract_patterns = [] for length in range(1, max_length): abstract_patterns.append([{tag: True} for tag in get_bilou(length)]) - self.matcher.add('Candidate', 'MWE', {}, abstract_patterns) + self.matcher.add('Candidate', 'MWE', {}, abstract_patterns, acceptor=self.accept_match) def add(self, Doc tokens): cdef int length = tokens.length assert length < self.max_length tags = get_bilou(length) assert len(tags) == length, length - + cdef int i for i in range(self.max_length): self._phrase_key[i] = 0 @@ -512,7 +512,7 @@ cdef class PhraseMatcher: def __call__(self, Doc doc): matches = [] - for label, start, end in self.matcher(doc, acceptor=self.accept_match): + for ent_id, label, start, end in self.matcher(doc): cand = doc[start : end] start = cand[0].idx end = cand[-1].idx + len(cand[-1]) @@ -526,7 +526,7 @@ cdef class PhraseMatcher: self(doc) yield doc - def accept_match(self, Doc doc, int label, int start, int end): + def accept_match(self, Doc doc, int ent_id, int label, int start, int end): assert (end - start) < self.max_length cdef int i, j for i in range(self.max_length): @@ -535,6 +535,6 @@ cdef class PhraseMatcher: self._phrase_key[i] = doc.c[j].lex.orth cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0) if self.phrase_ids.get(key): - return True + return (ent_id, label, start, end) else: return False