From 828cc91545458613dff701e804eaec442423e739 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 20 Sep 2017 21:54:31 +0200 Subject: [PATCH] Fix PhraseMatcher for spaCy 2 --- spacy/matcher.pyx | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index c75d23957..d321218b8 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -426,7 +426,7 @@ cdef class PhraseMatcher: self._phrase_key = self.mem.alloc(max_length, sizeof(attr_t)) self.max_length = max_length self.vocab = vocab - self.matcher = Matcher(self.vocab, {}) + self.matcher = Matcher(self.vocab) self.phrase_ids = PreshMap() for phrase in phrases: if len(phrase) < max_length: @@ -435,7 +435,7 @@ cdef class PhraseMatcher: abstract_patterns = [] for length in range(1, max_length): abstract_patterns.append([{tag: True} for tag in get_bilou(length)]) - self.matcher.add('Candidate', 'MWE', {}, abstract_patterns, acceptor=self.accept_match) + self.matcher.add('Candidate', None, *abstract_patterns) def add(self, Doc tokens): cdef int length = tokens.length @@ -454,22 +454,19 @@ cdef class PhraseMatcher: self.phrase_ids[key] = True def __call__(self, Doc doc): - matches = [] - for ent_id, label, start, end in self.matcher(doc): - cand = doc[start : end] - start = cand[0].idx - end = cand[-1].idx + len(cand[-1]) - matches.append((start, end, cand.root.tag_, cand.text, 'MWE')) - for match in matches: - doc.merge(*match) - return matches + matches = self.matcher(doc) + accepted = [] + for ent_id, start, end in matches: + if self.accept_match(doc, ent_id, start, end): + accepted.append((ent_id, start, end)) + return accepted def pipe(self, stream, batch_size=1000, n_threads=2): for doc in stream: self(doc) yield doc - def accept_match(self, Doc doc, attr_t ent_id, attr_t label, int start, int end): + def accept_match(self, Doc doc, attr_t ent_id, int start, int end): assert (end - start) < self.max_length cdef int i, j for i in range(self.max_length): @@ -478,6 +475,6 @@ cdef class PhraseMatcher: self._phrase_key[i] = doc.c[j].lex.orth cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0) if self.phrase_ids.get(key): - return (ent_id, label, start, end) + return True else: return False