diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index 061936105..d9804a922 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -12,10 +12,10 @@ from cython.operator cimport dereference as deref from murmurhash.mrmr cimport hash64 from libc.stdint cimport int32_t -try: - from libcpp.unordered_map cimport unordered_map as umap -except: - from libcpp.map cimport map as umap +# try: +# from libcpp.unordered_map cimport unordered_map as umap +# except: +# from libcpp.map cimport map as umap from .typedefs cimport attr_t from .typedefs cimport hash_t @@ -72,6 +72,7 @@ cdef enum action_t: ACCEPT_PREV PANIC + # Each token pattern consists of a quantifier and 0+ (attr, value) pairs. # A state is an (int, pattern pointer) pair, where the int is the start # position, and the pattern pointer shows where we're up to @@ -89,7 +90,7 @@ cdef struct TokenPatternC: ctypedef TokenPatternC* TokenPatternC_ptr -ctypedef pair[int, TokenPatternC_ptr] StateC +# ctypedef pair[int, TokenPatternC_ptr] StateC # Match Dictionary entry type cdef struct MatchEntryC: @@ -97,6 +98,19 @@ cdef struct MatchEntryC: int32_t end int32_t offset +# A state instance represents the information that defines a +# partial match +# start: the index of the first token in the partial match +# pattern: a pointer to the current token pattern in the full +# pattern +# last_match: The entry of the last span matched by the +# same pattern +cdef struct StateC: + int32_t start + TokenPatternC_ptr pattern + MatchEntryC* last_match + + cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL: pattern = mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC)) @@ -346,11 +360,15 @@ cdef class Matcher: cdef StateC state cdef int j = 0 cdef int k - cdef bint add_match,overlap = False - cdef TokenPatternC_ptr final_state - cdef umap[TokenPatternC_ptr,MatchEntryC] matches_dict - cdef umap[TokenPatternC_ptr,MatchEntryC].iterator state_match - cdef MatchEntryC new_match + cdef bint overlap = False + cdef MatchEntryC* state_match + cdef MatchEntryC* last_matches = self.mem.alloc(self.patterns.size(),sizeof(MatchEntryC)) + + for i in range(self.patterns.size()): + last_matches[i].start = 0 + last_matches[i].end = 0 + last_matches[i].offset = 0 + matches = [] for token_i in range(doc.length): token = &doc.c[token_i] @@ -361,7 +379,7 @@ cdef class Matcher: j=0 while j < n_partials: state = partials[j] - action = get_action(state.second, token) + action = get_action(state.pattern, token) j += 1 # Skip patterns that would overlap with an existing match # Patterns overlap an existing match if they point to the @@ -369,33 +387,29 @@ cdef class Matcher: # of said match. # Different patterns with the same label are allowed to # overlap. - final_state = state.second - while final_state.nr_attr != 0: - final_state+=1 - state_match = matches_dict.find(final_state) - if (state_match != matches_dict.end() - and state.first>deref(state_match).second.start - and state.first state_match.start + and state.start < state_match.end): continue if action == PANIC: raise Exception("Error selecting action in matcher") while action == ADVANCE_ZERO: - state.second += 1 - action = get_action(state.second, token) + state.pattern += 1 + action = get_action(state.pattern, token) if action == PANIC: raise Exception("Error selecting action in matcher") # ADVANCE_PLUS acts like REPEAT, but also pushes a partial that # acts like and ADVANCE_ZERO if action == ADVANCE_PLUS: - state.second += 1 + state.pattern += 1 partials.push_back(state) n_partials += 1 - state.second -= 1 + state.pattern -= 1 action = REPEAT if action == ADVANCE: - state.second += 1 + state.pattern += 1 # Check for partial matches that are at the same spec in the same pattern # Keep the longer of the matches @@ -404,7 +418,7 @@ cdef class Matcher: overlap=False for i in range(q): - if state.second == partials[i].second and state.first < partials[i].first: + if state.pattern == partials[i].pattern and state.start < partials[i].start: partials[i] = state j = i overlap = True @@ -413,7 +427,7 @@ cdef class Matcher: continue overlap=False for i in range(q): - if state.second == partials[i].second: + if state.pattern == partials[i].pattern: overlap = True break if overlap: @@ -434,60 +448,53 @@ cdef class Matcher: elif action in (ACCEPT, ACCEPT_PREV): # TODO: What to do about patterns starting with ZERO? Need # to adjust the start position. - start = state.first + start = state.start end = token_i+1 if action == ACCEPT else token_i - ent_id = state.second[1].attrs[0].value - # ent_id = get_pattern_key(state.second) - label = state.second[1].attrs[1].value + ent_id = state.pattern[1].attrs[0].value + label = state.pattern[1].attrs[1].value # Check that this match doesn't overlap with an earlier match. # Only overwrite an earlier match if it is a substring of this # match (i.e. it starts after this match starts). - final_state = state.second+1 - state_match = matches_dict.find(final_state) + state_match = state.last_match - if state_match == matches_dict.end(): - new_match.start = start - new_match.end = end - new_match.offset = len(matches) - matches_dict[final_state] = new_match + if start >= state_match.end: + state_match.start = start + state_match.end = end + state_match.offset = len(matches) matches.append((ent_id,start,end)) - elif start >= deref(state_match).second.end: - new_match.start = start - new_match.end = end - new_match.offset = len(matches) - matches_dict[final_state] = new_match - matches.append((ent_id,start,end)) - elif start <= deref(state_match).second.start and end>=deref(state_match).second.end: - i = deref(state_match).second.offset - matches[i] = (ent_id,start,end) - new_match.start = start - new_match.end = end - new_match.offset = i - matches_dict[final_state] = new_match + elif start <= state_match.start and end >= state_match.end: + if len(matches) == 0: + assert state_match.offset==0 + state_match.offset = 0 + matches.append((ent_id,start,end)) + else: + i = state_match.offset + matches[i] = (ent_id,start,end) + state_match.start = start + state_match.end = end else: pass partials.resize(q) n_partials = q # Check whether we open any new patterns on this token + i=0 for pattern in self.patterns: # Skip patterns that would overlap with an existing match - ent_id = get_pattern_key(pattern) - final_state = pattern - while final_state.nr_attr != 0: - final_state+=1 - state_match = matches_dict.find(final_state) - if (state_match != matches_dict.end() - and token_i>deref(state_match).second.start - and token_i state_match.start + and token_i < state_match.end): continue action = get_action(pattern, token) if action == PANIC: raise Exception("Error selecting action in matcher") while action in (ADVANCE_PLUS,ADVANCE_ZERO): if action == ADVANCE_PLUS: - state.first = token_i - state.second = pattern + state.start = token_i + state.pattern = pattern + state.last_match = state_match partials.push_back(state) n_partials += 1 pattern += 1 @@ -498,7 +505,7 @@ cdef class Matcher: j=0 overlap = False for j in range(q): - if pattern == partials[j].second: + if pattern == partials[j].pattern: overlap = True break if overlap: @@ -506,15 +513,17 @@ cdef class Matcher: if action == REPEAT: - state.first = token_i - state.second = pattern + state.start = token_i + state.pattern = pattern + state.last_match = state_match partials.push_back(state) n_partials += 1 elif action == ADVANCE: # TODO: What to do about patterns starting with ZERO? Need # to adjust the start position. - state.first = token_i - state.second = pattern + state.start = token_i + state.pattern = pattern + state.last_match = state_match partials.push_back(state) n_partials += 1 elif action in (ACCEPT, ACCEPT_PREV): @@ -523,60 +532,47 @@ cdef class Matcher: ent_id = pattern[1].attrs[0].value label = pattern[1].attrs[1].value - final_state = pattern+1 - state_match = matches_dict.find(final_state) - if state_match == matches_dict.end(): - new_match.start = start - new_match.end = end - new_match.offset = len(matches) - matches_dict[final_state] = new_match + if start >= state_match.end: + state_match.start = start + state_match.end = end + state_match.offset = len(matches) matches.append((ent_id,start,end)) - elif start >= deref(state_match).second.end: - new_match.start = start - new_match.end = end - new_match.offset = len(matches) - matches_dict[final_state] = new_match - matches.append((ent_id,start,end)) - elif start <= deref(state_match).second.start and end>=deref(state_match).second.end: - j = deref(state_match).second.offset - matches[j] = (ent_id,start,end) - new_match.start = start - new_match.end = end - new_match.offset = j - matches_dict[final_state] = new_match + if start <= state_match.start and end >= state_match.end: + if len(matches) == 0: + state_match.offset = 0 + matches.append((ent_id,start,end)) + else: + j = state_match.offset + matches[j] = (ent_id,start,end) + state_match.start = start + state_match.end = end else: pass # Look for open patterns that are actually satisfied for state in partials: - while state.second.quantifier in (ZERO, ZERO_ONE, ZERO_PLUS): - state.second += 1 - if state.second.nr_attr == 0: - start = state.first + while state.pattern.quantifier in (ZERO, ZERO_ONE, ZERO_PLUS): + state.pattern += 1 + if state.pattern.nr_attr == 0: + start = state.start end = len(doc) - ent_id = state.second.attrs[0].value - label = state.second.attrs[1].value - final_state = state.second - state_match = matches_dict.find(final_state) - if state_match == matches_dict.end(): - new_match.start = start - new_match.end = end - new_match.offset = len(matches) - matches_dict[final_state] = new_match + ent_id = state.pattern.attrs[0].value + label = state.pattern.attrs[1].value + state_match = state.last_match + if start >= state_match.end: + state_match.start = start + state_match.end = end + state_match.offset = len(matches) matches.append((ent_id,start,end)) - elif start >= deref(state_match).second.end: - new_match.start = start - new_match.end = end - new_match.offset = len(matches) - matches_dict[final_state] = new_match - matches.append((ent_id,start,end)) - elif start <= deref(state_match).second.start and end>=deref(state_match).second.end: - j = deref(state_match).second.offset - matches[j] = (ent_id,start,end) - new_match.start = start - new_match.end = end - new_match.offset = j - matches_dict[final_state] = new_match + if start <= state_match.start and end >= state_match.end: + j = state_match.offset + if len(matches) == 0: + state_match.offset = 0 + matches.append((ent_id,start,end)) + else: + matches[j] = (ent_id,start,end) + state_match.start = start + state_match.end = end else: pass for i, (ent_id, start, end) in enumerate(matches):