Move pattern_id out of TokenPattern

This commit is contained in:
Matthew Honnibal 2018-02-12 12:05:54 +01:00
parent d34c732635
commit b00326a7fe
1 changed files with 20 additions and 12 deletions

View File

@ -42,13 +42,12 @@ cdef struct ActionC:
cdef struct PatternStateC:
TokenPatternC* state
int32_t pattern_id
int32_t start
ActionC last_action
cdef struct MatchC:
int32_t pattern_id
attr_t pattern_id
int32_t start
int32_t end
@ -57,15 +56,16 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
cdef vector[PatternStateC] init_states
cdef ActionC null_action = ActionC(-1, -1, -1)
for i in range(n):
init_states.push_back(PatternStateC(patterns[i], i, -1, last_action=null_action))
init_states.push_back(PatternStateC(patterns[i], -1, last_action=null_action))
cdef vector[PatternStateC] curr_states
cdef vector[PatternStateC] nexts
cdef vector[MatchC] matches
cdef PreshMap cache = PreshMap()
cdef PreshMap cache
cdef Pool mem = Pool()
# TODO: Prefill this with the extra attribute values.
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
for i in range(doc.length):
cache = PreshMap()
nexts.clear()
for j in range(curr_states.size()):
action = get_action(curr_states[j], &doc.c[i], extra_attrs[i], cache)
@ -79,12 +79,13 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
# Filter out matches that have a longer equivalent.
longest_matches = {}
for i in range(matches.size()):
key = matches[i].pattern_id, matches[i].start
key = (matches[i].pattern_id, matches[i].start)
length = matches[i].end - matches[i].start
if key not in longest_matches or length > longest_matches[key]:
longest_matches[key] = length
return [(pattern_id, start, length-start)
for (pattern_id, start), length in longest_matches]
print(longest_matches)
return [(pattern_id, start, start+length)
for (pattern_id, start), length in longest_matches.items()]
cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts,
@ -92,14 +93,15 @@ cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts,
if state.start == -1:
state.start = token
if action.is_match:
ent_id = state.state[1].attrs.value
matches.push_back(
MatchC(pattern_id=state.pattern_id, start=state.start, end=token+1))
MatchC(pattern_id=ent_id, start=state.start, end=token+1))
if action.keep_state:
nexts.push_back(PatternStateC(pattern_id=pattern_id,
start=state.start, state=state.state, last_action=action))
nexts.push_back(PatternStateC(start=state.start, state=state.state,
last_action=action))
if action.advance_state:
nexts.push_back(PatternStateC(pattern_id=pattern_id,
start=state.start, state=state.state+1, last_action=action))
nexts.push_back(PatternStateC(start=state.start,
state=state.state+1, last_action=action))
cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs,
@ -387,6 +389,12 @@ cdef class Matcher:
matches = find_matches(&self.patterns[0], self.patterns.size(), doc)
return matches
def _normalize_key(self, key):
if isinstance(key, basestring):
return self.vocab.strings.add(key)
else:
return key
def unpickle_matcher(vocab, patterns, callbacks):
matcher = Matcher(vocab)