Make helper function to get longest matches

This commit is contained in:
Matthew Honnibal 2018-02-15 15:26:15 +01:00
parent d19dc67886
commit 9ebf2fe7c3
1 changed file with 17 additions and 10 deletions

View File

@ -94,25 +94,21 @@ cdef struct MatchC:
cdef find_matches(TokenPatternC** patterns, int n, Doc doc): cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
cdef vector[PatternStateC] states cdef vector[PatternStateC] states
cdef vector[MatchC] matches cdef vector[MatchC] matches
cdef PatternStateC state
cdef Pool mem = Pool() cdef Pool mem = Pool()
# TODO: Prefill this with the extra attribute values. # TODO: Prefill this with the extra attribute values.
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*)) extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
# Main loop # Main loop
cdef int i, j
for i in range(doc.length): for i in range(doc.length):
for j in range(n): for j in range(n):
states.push_back(PatternStateC(patterns[j], i, 0)) states.push_back(PatternStateC(patterns[j], i, 0))
transition_states(states, matches, &doc.c[i], extra_attrs[i]) transition_states(states, matches, &doc.c[i], extra_attrs[i])
# Handle matches that end in 0-width patterns # Handle matches that end in 0-width patterns
finish_states(matches, states) finish_states(matches, states)
# Filter out matches that have a longer equivalent. return [(matches[i].pattern_id, matches[i].start, matches[i].start+matches[i].length)
longest_matches = {} for i in range(matches.size())]
for i in range(matches.size()):
key = (matches[i].pattern_id, matches[i].start)
length = matches[i].length
if key not in longest_matches or length > longest_matches[key]:
longest_matches[key] = length
return [(pattern_id, start, start+length)
for (pattern_id, start), length in longest_matches.items()]
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches, cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
@ -493,7 +489,6 @@ cdef class Matcher:
self(doc) self(doc)
yield doc yield doc
def __call__(self, Doc doc): def __call__(self, Doc doc):
"""Find all token sequences matching the supplied pattern. """Find all token sequences matching the supplied pattern.
@ -524,6 +519,18 @@ def unpickle_matcher(vocab, patterns, callbacks):
return matcher return matcher
def _get_longest_matches(matches):
'''Filter out matches that have a longer equivalent.'''
longest_matches = {}
for pattern_id, start, end in matches:
key = (pattern_id, start)
length = end-start
if key not in longest_matches or length > longest_matches[key]:
longest_matches[key] = length
return [(pattern_id, start, start+length)
for (pattern_id, start), length in longest_matches.items()]
def get_bilou(length): def get_bilou(length):
if length == 1: if length == 1:
return [U_ENT] return [U_ENT]