Move pattern_id out of TokenPattern

This commit is contained in:
Matthew Honnibal 2018-02-12 12:05:54 +01:00
parent d34c732635
commit b00326a7fe
1 changed file with 20 additions and 12 deletions

View File

@ -42,13 +42,12 @@ cdef struct ActionC:
cdef struct PatternStateC: cdef struct PatternStateC:
TokenPatternC* state TokenPatternC* state
int32_t pattern_id
int32_t start int32_t start
ActionC last_action ActionC last_action
cdef struct MatchC: cdef struct MatchC:
int32_t pattern_id attr_t pattern_id
int32_t start int32_t start
int32_t end int32_t end
@ -57,15 +56,16 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
cdef vector[PatternStateC] init_states cdef vector[PatternStateC] init_states
cdef ActionC null_action = ActionC(-1, -1, -1) cdef ActionC null_action = ActionC(-1, -1, -1)
for i in range(n): for i in range(n):
init_states.push_back(PatternStateC(patterns[i], i, -1, last_action=null_action)) init_states.push_back(PatternStateC(patterns[i], -1, last_action=null_action))
cdef vector[PatternStateC] curr_states cdef vector[PatternStateC] curr_states
cdef vector[PatternStateC] nexts cdef vector[PatternStateC] nexts
cdef vector[MatchC] matches cdef vector[MatchC] matches
cdef PreshMap cache = PreshMap() cdef PreshMap cache
cdef Pool mem = Pool() cdef Pool mem = Pool()
# TODO: Prefill this with the extra attribute values. # TODO: Prefill this with the extra attribute values.
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*)) extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
for i in range(doc.length): for i in range(doc.length):
cache = PreshMap()
nexts.clear() nexts.clear()
for j in range(curr_states.size()): for j in range(curr_states.size()):
action = get_action(curr_states[j], &doc.c[i], extra_attrs[i], cache) action = get_action(curr_states[j], &doc.c[i], extra_attrs[i], cache)
@ -79,12 +79,13 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
# Filter out matches that have a longer equivalent. # Filter out matches that have a longer equivalent.
longest_matches = {} longest_matches = {}
for i in range(matches.size()): for i in range(matches.size()):
key = matches[i].pattern_id, matches[i].start key = (matches[i].pattern_id, matches[i].start)
length = matches[i].end - matches[i].start length = matches[i].end - matches[i].start
if key not in longest_matches or length > longest_matches[key]: if key not in longest_matches or length > longest_matches[key]:
longest_matches[key] = length longest_matches[key] = length
return [(pattern_id, start, length-start) print(longest_matches)
for (pattern_id, start), length in longest_matches] return [(pattern_id, start, start+length)
for (pattern_id, start), length in longest_matches.items()]
cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts, cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts,
@ -92,14 +93,15 @@ cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts,
if state.start == -1: if state.start == -1:
state.start = token state.start = token
if action.is_match: if action.is_match:
ent_id = state.state[1].attrs.value
matches.push_back( matches.push_back(
MatchC(pattern_id=state.pattern_id, start=state.start, end=token+1)) MatchC(pattern_id=ent_id, start=state.start, end=token+1))
if action.keep_state: if action.keep_state:
nexts.push_back(PatternStateC(pattern_id=pattern_id, nexts.push_back(PatternStateC(start=state.start, state=state.state,
start=state.start, state=state.state, last_action=action)) last_action=action))
if action.advance_state: if action.advance_state:
nexts.push_back(PatternStateC(pattern_id=pattern_id, nexts.push_back(PatternStateC(start=state.start,
start=state.start, state=state.state+1, last_action=action)) state=state.state+1, last_action=action))
cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs, cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs,
@ -387,6 +389,12 @@ cdef class Matcher:
matches = find_matches(&self.patterns[0], self.patterns.size(), doc) matches = find_matches(&self.patterns[0], self.patterns.size(), doc)
return matches return matches
def _normalize_key(self, key):
if isinstance(key, basestring):
return self.vocab.strings.add(key)
else:
return key
def unpickle_matcher(vocab, patterns, callbacks): def unpickle_matcher(vocab, patterns, callbacks):
matcher = Matcher(vocab) matcher = Matcher(vocab)