diff --git a/spacy/matcher2.pyx b/spacy/matcher2.pyx
index 37aa5ed61..4545a2f31 100644
--- a/spacy/matcher2.pyx
+++ b/spacy/matcher2.pyx
@@ -35,28 +35,30 @@ cdef struct TokenPatternC:
 
 
 cdef struct ActionC:
-    char is_match
-    char keep_state
-    char advance_state
+    char emit_match
+    char next_state_next_token
+    char next_state_same_token
+    char same_state_next_token
 
 
 cdef struct PatternStateC:
-    TokenPatternC* state
+    TokenPatternC* pattern
     int32_t start
-    ActionC last_action
+    int32_t length
 
 
 cdef struct MatchC:
     attr_t pattern_id
     int32_t start
-    int32_t end
+    int32_t length
 
 
 cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
+    print("N patterns: ", n)
     cdef vector[PatternStateC] init_states
-    cdef ActionC null_action = ActionC(-1, -1, -1)
+    cdef ActionC null_action = ActionC(-1, -1, -1, -1)
     for i in range(n):
-        init_states.push_back(PatternStateC(patterns[i], -1, last_action=null_action))
+        init_states.push_back(PatternStateC(patterns[i], -1, 0))
     cdef vector[PatternStateC] curr_states
     cdef vector[PatternStateC] nexts
     cdef vector[MatchC] matches
@@ -65,48 +67,65 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
     # TODO: Prefill this with the extra attribute values.
     extra_attrs = mem.alloc(len(doc), sizeof(attr_t*))
     for i in range(doc.length):
-        cache = PreshMap()
         nexts.clear()
+        cache = PreshMap()
         for j in range(curr_states.size()):
             transition(matches, nexts,
-                curr_states[j], i, doc, extra_attrs, cache)
+                curr_states[j], i, &doc.c[i], extra_attrs[i], cache)
         for j in range(init_states.size()):
             transition(matches, nexts,
-                init_states[j], i, doc, extra_attrs, cache)
+                init_states[j], i, &doc.c[i], extra_attrs[i], cache)
         nexts, curr_states = curr_states, nexts
+    # Handle patterns that end with zero-width
+    for j in range(curr_states.size()):
+        state = curr_states[j]
+        while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE):
+            is_final = get_is_final(state)
+            if is_final:
+                ent_id = state.pattern[1].attrs.value
+                matches.push_back(
+                    MatchC(pattern_id=ent_id, start=state.start, length=state.length))
+                break
+            else:
+                state.pattern += 1
     # Filter out matches that have a longer equivalent.
     longest_matches = {}
     for i in range(matches.size()):
         key = (matches[i].pattern_id, matches[i].start)
-        length = matches[i].end - matches[i].start
+        length = matches[i].length
         if key not in longest_matches or length > longest_matches[key]:
             longest_matches[key] = length
-    print(longest_matches)
     return [(pattern_id, start, start+length)
             for (pattern_id, start), length in longest_matches.items()]
 
 
 cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts,
-        PatternStateC state, int token,
-        Doc doc, const attr_t* const* extra_attrs, PreshMap cache) except *:
-    action = get_action(state, &doc.c[token], extra_attrs[token], cache)
+        PatternStateC state, int i, const TokenC* token, const attr_t* extra_attrs,
+        PreshMap cache) except *:
+    action = get_action(state, token, extra_attrs, cache)
     if state.start == -1:
-        state.start = token
-    if action.is_match:
-        ent_id = state.state[1].attrs.value
+        state.start = i
+    if action.emit_match == 1:
+        ent_id = state.pattern[1].attrs.value
         matches.push_back(
-            MatchC(pattern_id=ent_id, start=state.start, end=token+1))
-    if action.advance_state:
+            MatchC(pattern_id=ent_id, start=state.start, length=state.length+1))
+    elif action.emit_match == 2:
+        ent_id = state.pattern[1].attrs.value
+        matches.push_back(
+            MatchC(pattern_id=ent_id, start=state.start, length=state.length))
+    if action.next_state_next_token:
         nexts.push_back(PatternStateC(start=state.start,
-            state=state.state+1, last_action=action))
+            pattern=&state.pattern[1], length=state.length+1))
+    if action.same_state_next_token:
+        nexts.push_back(PatternStateC(start=state.start,
+            pattern=state.pattern, length=state.length+1))
     cdef PatternStateC next_state
-    if action.keep_state and token < doc.length:
-        # Keeping the state needs to not consume a token, so we call transition
-        # with the next state
-        next_state = PatternStateC(start=state.start, state=state.state+1,
-            last_action=action)
-        transition(matches, nexts, next_state, token, doc, extra_attrs, cache)
-
+    if action.next_state_same_token:
+        # 0+ and ? non-matches need to not consume a token, so we call transition
+        # with the same state
+        next_state = PatternStateC(start=state.start, pattern=&state.pattern[1],
+            length=state.length)
+        transition(matches, nexts, next_state, i, token, extra_attrs, cache)
 
 
 cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs,
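
The hunk above drops the last_action bookkeeping in favour of explicit (pattern, start, length) states, and transition() now receives the token pointer and its extra attributes directly. Outside the patch itself, the control flow that find_matches() ends up with can be sketched in plain Python roughly as below; the transition callable and the tuple states are stand-ins for the C-level types:

    def find_matches_sketch(patterns, doc, transition):
        # Partial states are (pattern_position, start, length) tuples.
        curr_states = []
        matches = []                       # (pattern_id, start, length) triples
        for i, token in enumerate(doc):
            nexts = []
            # Feed both the surviving partials and a fresh initial state for
            # every pattern through the transition function at each token.
            for state in curr_states + [(pattern, -1, 0) for pattern in patterns]:
                transition(matches, nexts, state, i, token)
            curr_states = nexts
        # Keep only the longest match per (pattern_id, start), mirroring the
        # "longer equivalent" filter at the end of find_matches().
        longest = {}
        for pattern_id, start, length in matches:
            key = (pattern_id, start)
            if key not in longest or length > longest[key]:
                longest[key] = length
        return [(pattern_id, start, start + length)
                for (pattern_id, start), length in longest.items()]
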
@@ -117,74 +136,108 @@ cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t*
     b) What's the quantifier? [1, 0+, ?]
     c) Is this the last specification? [final, non-final]
 
-    We therefore have 12 cases to consider. For each case, we need to know
-    whether to emit a match, whether to keep the current state in the partials,
-    and whether to add an advanced state to the partials.
+    We can transition in the following ways:
 
-    We therefore have eight possible results for these three booleans, which
-    we'll code as 000, 001 etc.
+    a) Do we emit a match?
+    b) Do we add a state with (next state, next token)?
+    c) Do we add a state with (next state, same token)?
+    d) Do we add a state with (same state, next token)?
+
+    We'll code the actions as boolean strings, so 0000 means no to all 4,
+    1000 means match but no states added, etc.
 
     1:
-      - Match, final:
-        100
-      - Match, non-final:
-        001
-      - No match:
-        000
+      Yes, final:
+        1000
+      Yes, non-final:
+        0100
+      No, final:
+        0000
+      No, non-final
+        0000
     0+:
-      - Match, final:
-        100
-      - Match, non-final:
-        011
-      - Non-match, final:
-        100
-      - Non-match, non-final:
-        010
+      Yes, final:
+        1001
+      Yes, non-final:
+        0011
+      No, final:
+        1000 (note: Don't include last token!)
+      No, non-final:
+        0010
+    ?:
+      Yes, final:
+        1000
+      Yes, non-final:
+        0100
+      No, final:
+        1000 (note: Don't include last token!)
+      No, non-final:
+        0010
 
     Problem: If a quantifier is matching, we're adding a lot of open partials
-    Question: Is it worth doing a lookahead, to see if we add?
     '''
-    cached_match = cache.get(state.state.key)
+    cached_match = cache.get(state.pattern.key)
     cdef char is_match
     if cached_match == 0:
         is_match = get_is_match(state, token, extra_attrs)
         cached_match = is_match + 1
-        cache.set(state.state.key, cached_match)
+        cache.set(state.pattern.key, cached_match)
     elif cached_match == 1:
         is_match = 0
     else:
         is_match = 1
-    quantifier = get_quantifier(state, token)
-    is_final = get_is_final(state, token)
+    quantifier = get_quantifier(state)
+    is_final = get_is_final(state)
+    if quantifier == ZERO:
+        is_match = not is_match
+        quantifier = ONE
     if quantifier == ONE:
-        if not is_match:
-            return ActionC(is_match=0, keep_state=0, advance_state=0)
-        elif is_final:
-            return ActionC(is_match=1, keep_state=0, advance_state=0)
-        else:
-            return ActionC(is_match=0, keep_state=0, advance_state=1)
+        if is_match and is_final:
+            # Yes, final: 1000
+            return ActionC(1, 0, 0, 0)
+        elif is_match and not is_final:
+            # Yes, non-final: 0100
+            return ActionC(0, 1, 0, 0)
+        elif not is_match and is_final:
+            # No, final: 0000
+            return ActionC(0, 0, 0, 0)
+        else:
+            # No, non-final 0000
+            return ActionC(0, 0, 0, 0)
+
     elif quantifier == ZERO_PLUS:
-        if is_final:
-            return ActionC(is_match=1, keep_state=0, advance_state=0)
-        elif is_match:
-            return ActionC(is_match=0, keep_state=1, advance_state=1)
-        else:
-            return ActionC(is_match=0, keep_state=1, advance_state=0)
+        if is_match and is_final:
+            # Yes, final: 1001
+            return ActionC(1, 0, 0, 1)
+        elif is_match and not is_final:
+            # Yes, non-final: 0011
+            return ActionC(0, 0, 1, 1)
+        elif not is_match and is_final:
+            # No, final 1000 (note: Don't include last token!)
+            return ActionC(2, 0, 0, 0)
+        else:
+            # No, non-final 0010
+            return ActionC(0, 0, 1, 0)
    elif quantifier == ZERO_ONE:
-        if is_final:
-            return ActionC(is_match=1, keep_state=0, advance_state=0)
-        elif is_match:
-            if state.last_action.keep_state:
-                return ActionC(is_match=0, keep_state=0, advance_state=1)
-            else:
-                return ActionC(is_match=0, keep_state=1, advance_state=1)
+        if is_match and is_final:
+            # Yes, final: 1000
+            return ActionC(1, 0, 0, 0)
+        elif is_match and not is_final:
+            # Yes, non-final: 0100
+            return ActionC(0, 1, 0, 0)
+        elif not is_match and is_final:
+            # No, final 1000 (note: Don't include last token!)
+            return ActionC(2, 0, 0, 0)
+        else:
+            # No, non-final 0010
+            return ActionC(0, 0, 1, 0)
     else:
         print(quantifier, is_match, is_final)
         raise ValueError
 
 
 cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
-    spec = state.state
+    spec = state.pattern
     for attr in spec.attrs[:spec.nr_attr]:
         if get_token_attr(token, attr.attr) != attr.value:
             return 0
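
The rewritten get_action() above is a pure lookup over (quantifier, is_match, is_final), with the new ZERO quantifier folded into ONE by negating the token test. Outside the patch itself, the same decision table can be written as a plain-Python sketch; the string constants stand in for the module's quantifier values, and emit_match == 2 is the "emit a match that excludes the current token" case used by 0+ and ?:

    ONE, ZERO, ZERO_ONE, ZERO_PLUS = 'ONE', 'ZERO', 'ZERO_ONE', 'ZERO_PLUS'

    def get_action_sketch(quantifier, is_match, is_final):
        # Returns (emit_match, next_state_next_token,
        #          next_state_same_token, same_state_next_token).
        if quantifier == ZERO:
            # '!' negates the token test and then behaves like a ONE quantifier.
            is_match, quantifier = not is_match, ONE
        table = {
            (ONE, True, True):         (1, 0, 0, 0),
            (ONE, True, False):        (0, 1, 0, 0),
            (ONE, False, True):        (0, 0, 0, 0),
            (ONE, False, False):       (0, 0, 0, 0),
            (ZERO_PLUS, True, True):   (1, 0, 0, 1),
            (ZERO_PLUS, True, False):  (0, 0, 1, 1),
            (ZERO_PLUS, False, True):  (2, 0, 0, 0),  # match, but don't include this token
            (ZERO_PLUS, False, False): (0, 0, 1, 0),
            (ZERO_ONE, True, True):    (1, 0, 0, 0),
            (ZERO_ONE, True, False):   (0, 1, 0, 0),
            (ZERO_ONE, False, True):   (2, 0, 0, 0),  # match, but don't include this token
            (ZERO_ONE, False, False):  (0, 0, 1, 0),
        }
        return table[(quantifier, bool(is_match), bool(is_final))]
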
@@ -192,15 +245,15 @@ cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* e
     return 1
 
 
-cdef char get_is_final(PatternStateC state, const TokenC* token) nogil:
-    if state.state[1].attrs[0].attr == ID and state.state[1].nr_attr == 0:
+cdef char get_is_final(PatternStateC state) nogil:
+    if state.pattern[1].attrs[0].attr == ID and state.pattern[1].nr_attr == 0:
         return 1
     else:
         return 0
 
 
-cdef char get_quantifier(PatternStateC state, const TokenC* token) nogil:
-    return state.state.quantifier
+cdef char get_quantifier(PatternStateC state) nogil:
+    return state.pattern.quantifier
 
 
 cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
@@ -232,7 +285,7 @@ cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
 def _convert_strings(token_specs, string_store):
     # Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
     operators = {'*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
-                 '?': (ZERO_ONE,), '1': (ONE,)}
+                 '?': (ZERO_ONE,), '1': (ONE,), '!': (ZERO,)}
     tokens = []
     op = ONE
     for spec in token_specs:
@@ -392,6 +445,10 @@ cdef class Matcher:
         `doc[start:end]`. The `label_id` and `key` are both integers.
         """
         matches = find_matches(&self.patterns[0], self.patterns.size(), doc)
+        for i, (key, start, end) in enumerate(matches):
+            on_match = self._callbacks.get(key, None)
+            if on_match is not None:
+                on_match(self, doc, i, matches)
        return matches
 
     def _normalize_key(self, key):
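
Outside the patch itself, a sketch of how the new '!' operator and the callback dispatch added to __call__ might be exercised, assuming the Matcher in matcher2.pyx keeps the add(key, on_match, *patterns) signature and the (key, start, end) match tuples of the existing spacy.matcher.Matcher; the spacy.matcher2 import path is an assumption as well:

    from spacy.tokens import Doc
    from spacy.vocab import Vocab
    from spacy.matcher2 import Matcher    # hypothetical import path for this module

    vocab = Vocab()
    doc = Doc(vocab, words=['this', 'is', 'very', 'good'])

    def on_match(matcher, doc, i, matches):
        # Called by __call__ for each match whose key has a registered callback.
        key, start, end = matches[i]
        print('matched:', doc[start:end].text)

    matcher = Matcher(vocab)
    # '!' maps to the new ZERO quantifier: the token must fail the spec,
    # so 'very' (which is not 'not') satisfies the first token here.
    matcher.add('GOOD', on_match, [{'LOWER': 'not', 'OP': '!'}, {'LOWER': 'good'}])
    matches = matcher(doc)    # (key, start, end) triples; on_match fires per match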