mirror of https://github.com/explosion/spaCy.git
Make helper function to get longest matches
This commit is contained in:
parent
d19dc67886
commit
9ebf2fe7c3
|
@ -94,25 +94,21 @@ cdef struct MatchC:
|
|||
cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
|
||||
cdef vector[PatternStateC] states
|
||||
cdef vector[MatchC] matches
|
||||
cdef PatternStateC state
|
||||
cdef Pool mem = Pool()
|
||||
# TODO: Prefill this with the extra attribute values.
|
||||
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
|
||||
# Main loop
|
||||
cdef int i, j
|
||||
for i in range(doc.length):
|
||||
for j in range(n):
|
||||
states.push_back(PatternStateC(patterns[j], i, 0))
|
||||
transition_states(states, matches, &doc.c[i], extra_attrs[i])
|
||||
# Handle matches that end in 0-width patterns
|
||||
finish_states(matches, states)
|
||||
# Filter out matches that have a longer equivalent.
|
||||
longest_matches = {}
|
||||
for i in range(matches.size()):
|
||||
key = (matches[i].pattern_id, matches[i].start)
|
||||
length = matches[i].length
|
||||
if key not in longest_matches or length > longest_matches[key]:
|
||||
longest_matches[key] = length
|
||||
return [(pattern_id, start, start+length)
|
||||
for (pattern_id, start), length in longest_matches.items()]
|
||||
return [(matches[i].pattern_id, matches[i].start, matches[i].start+matches[i].length)
|
||||
for i in range(matches.size())]
|
||||
|
||||
|
||||
|
||||
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
|
||||
|
@ -493,7 +489,6 @@ cdef class Matcher:
|
|||
self(doc)
|
||||
yield doc
|
||||
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
"""Find all token sequences matching the supplied pattern.
|
||||
|
||||
|
@ -524,6 +519,18 @@ def unpickle_matcher(vocab, patterns, callbacks):
|
|||
return matcher
|
||||
|
||||
|
||||
def _get_longest_matches(matches):
|
||||
'''Filter out matches that have a longer equivalent.'''
|
||||
longest_matches = {}
|
||||
for pattern_id, start, end in matches:
|
||||
key = (pattern_id, start)
|
||||
length = end-start
|
||||
if key not in longest_matches or length > longest_matches[key]:
|
||||
longest_matches[key] = length
|
||||
return [(pattern_id, start, start+length)
|
||||
for (pattern_id, start), length in longest_matches.items()]
|
||||
|
||||
|
||||
def get_bilou(length):
|
||||
if length == 1:
|
||||
return [U_ENT]
|
||||
|
|
Loading…
Reference in New Issue