mirror of https://github.com/explosion/spaCy.git
Fix zero-width quantifiers. Passes test_matcher
This commit is contained in:
parent
1b01685f47
commit
b4cc39eb74
|
@ -35,28 +35,30 @@ cdef struct TokenPatternC:
|
||||||
|
|
||||||
|
|
||||||
cdef struct ActionC:
|
cdef struct ActionC:
|
||||||
char is_match
|
char emit_match
|
||||||
char keep_state
|
char next_state_next_token
|
||||||
char advance_state
|
char next_state_same_token
|
||||||
|
char same_state_next_token
|
||||||
|
|
||||||
|
|
||||||
cdef struct PatternStateC:
|
cdef struct PatternStateC:
|
||||||
TokenPatternC* state
|
TokenPatternC* pattern
|
||||||
int32_t start
|
int32_t start
|
||||||
ActionC last_action
|
int32_t length
|
||||||
|
|
||||||
|
|
||||||
cdef struct MatchC:
|
cdef struct MatchC:
|
||||||
attr_t pattern_id
|
attr_t pattern_id
|
||||||
int32_t start
|
int32_t start
|
||||||
int32_t end
|
int32_t length
|
||||||
|
|
||||||
|
|
||||||
cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
|
cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
|
||||||
|
print("N patterns: ", n)
|
||||||
cdef vector[PatternStateC] init_states
|
cdef vector[PatternStateC] init_states
|
||||||
cdef ActionC null_action = ActionC(-1, -1, -1)
|
cdef ActionC null_action = ActionC(-1, -1, -1, -1)
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
init_states.push_back(PatternStateC(patterns[i], -1, last_action=null_action))
|
init_states.push_back(PatternStateC(patterns[i], -1, 0))
|
||||||
cdef vector[PatternStateC] curr_states
|
cdef vector[PatternStateC] curr_states
|
||||||
cdef vector[PatternStateC] nexts
|
cdef vector[PatternStateC] nexts
|
||||||
cdef vector[MatchC] matches
|
cdef vector[MatchC] matches
|
||||||
|
@ -65,48 +67,65 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
|
||||||
# TODO: Prefill this with the extra attribute values.
|
# TODO: Prefill this with the extra attribute values.
|
||||||
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
|
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
|
||||||
for i in range(doc.length):
|
for i in range(doc.length):
|
||||||
cache = PreshMap()
|
|
||||||
nexts.clear()
|
nexts.clear()
|
||||||
|
cache = PreshMap()
|
||||||
for j in range(curr_states.size()):
|
for j in range(curr_states.size()):
|
||||||
transition(matches, nexts,
|
transition(matches, nexts,
|
||||||
curr_states[j], i, doc, extra_attrs, cache)
|
curr_states[j], i, &doc.c[i], extra_attrs[i], cache)
|
||||||
for j in range(init_states.size()):
|
for j in range(init_states.size()):
|
||||||
transition(matches, nexts,
|
transition(matches, nexts,
|
||||||
init_states[j], i, doc, extra_attrs, cache)
|
init_states[j], i, &doc.c[i], extra_attrs[i], cache)
|
||||||
nexts, curr_states = curr_states, nexts
|
nexts, curr_states = curr_states, nexts
|
||||||
|
# Handle patterns that end with zero-width
|
||||||
|
for j in range(curr_states.size()):
|
||||||
|
state = curr_states[j]
|
||||||
|
while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE):
|
||||||
|
is_final = get_is_final(state)
|
||||||
|
if is_final:
|
||||||
|
ent_id = state.pattern[1].attrs.value
|
||||||
|
matches.push_back(
|
||||||
|
MatchC(pattern_id=ent_id, start=state.start, length=state.length))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
state.pattern += 1
|
||||||
# Filter out matches that have a longer equivalent.
|
# Filter out matches that have a longer equivalent.
|
||||||
longest_matches = {}
|
longest_matches = {}
|
||||||
for i in range(matches.size()):
|
for i in range(matches.size()):
|
||||||
key = (matches[i].pattern_id, matches[i].start)
|
key = (matches[i].pattern_id, matches[i].start)
|
||||||
length = matches[i].end - matches[i].start
|
length = matches[i].length
|
||||||
if key not in longest_matches or length > longest_matches[key]:
|
if key not in longest_matches or length > longest_matches[key]:
|
||||||
longest_matches[key] = length
|
longest_matches[key] = length
|
||||||
print(longest_matches)
|
|
||||||
return [(pattern_id, start, start+length)
|
return [(pattern_id, start, start+length)
|
||||||
for (pattern_id, start), length in longest_matches.items()]
|
for (pattern_id, start), length in longest_matches.items()]
|
||||||
|
|
||||||
|
|
||||||
cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts,
|
cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts,
|
||||||
PatternStateC state, int token,
|
PatternStateC state, int i, const TokenC* token, const attr_t* extra_attrs,
|
||||||
Doc doc, const attr_t* const* extra_attrs, PreshMap cache) except *:
|
PreshMap cache) except *:
|
||||||
action = get_action(state, &doc.c[token], extra_attrs[token], cache)
|
action = get_action(state, token, extra_attrs, cache)
|
||||||
if state.start == -1:
|
if state.start == -1:
|
||||||
state.start = token
|
state.start = i
|
||||||
if action.is_match:
|
if action.emit_match == 1:
|
||||||
ent_id = state.state[1].attrs.value
|
ent_id = state.pattern[1].attrs.value
|
||||||
matches.push_back(
|
matches.push_back(
|
||||||
MatchC(pattern_id=ent_id, start=state.start, end=token+1))
|
MatchC(pattern_id=ent_id, start=state.start, length=state.length+1))
|
||||||
if action.advance_state:
|
elif action.emit_match == 2:
|
||||||
|
ent_id = state.pattern[1].attrs.value
|
||||||
|
matches.push_back(
|
||||||
|
MatchC(pattern_id=ent_id, start=state.start, length=state.length))
|
||||||
|
if action.next_state_next_token:
|
||||||
nexts.push_back(PatternStateC(start=state.start,
|
nexts.push_back(PatternStateC(start=state.start,
|
||||||
state=state.state+1, last_action=action))
|
pattern=&state.pattern[1], length=state.length+1))
|
||||||
|
if action.same_state_next_token:
|
||||||
|
nexts.push_back(PatternStateC(start=state.start,
|
||||||
|
pattern=state.pattern, length=state.length+1))
|
||||||
cdef PatternStateC next_state
|
cdef PatternStateC next_state
|
||||||
if action.keep_state and token < doc.length:
|
if action.next_state_same_token:
|
||||||
# Keeping the state needs to not consume a token, so we call transition
|
# 0+ and ? non-matches need to not consume a token, so we call transition
|
||||||
# with the next state
|
# with the same state
|
||||||
next_state = PatternStateC(start=state.start, state=state.state+1,
|
next_state = PatternStateC(start=state.start, pattern=&state.pattern[1],
|
||||||
last_action=action)
|
length=state.length)
|
||||||
transition(matches, nexts, next_state, token, doc, extra_attrs, cache)
|
transition(matches, nexts, next_state, i, token, extra_attrs, cache)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs,
|
cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs,
|
||||||
|
@ -117,74 +136,108 @@ cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t*
|
||||||
b) What's the quantifier? [1, 0+, ?]
|
b) What's the quantifier? [1, 0+, ?]
|
||||||
c) Is this the last specification? [final, non-final]
|
c) Is this the last specification? [final, non-final]
|
||||||
|
|
||||||
We therefore have 12 cases to consider. For each case, we need to know
|
We can transition in the following ways:
|
||||||
whether to emit a match, whether to keep the current state in the partials,
|
|
||||||
and whether to add an advanced state to the partials.
|
|
||||||
|
|
||||||
We therefore have eight possible results for these three booleans, which
|
a) Do we emit a match?
|
||||||
we'll code as 000, 001 etc.
|
b) Do we add a state with (next state, next token)?
|
||||||
|
c) Do we add a state with (next state, same token)?
|
||||||
|
d) Do we add a state with (same state, next token)?
|
||||||
|
|
||||||
|
We'll code the actions as boolean strings, so 0000 means no to all 4,
|
||||||
|
1000 means match but no states added, etc.
|
||||||
|
|
||||||
1:
|
1:
|
||||||
- Match, final:
|
Yes, final:
|
||||||
100
|
1000
|
||||||
- Match, non-final:
|
Yes, non-final:
|
||||||
001
|
0100
|
||||||
- No match:
|
No, final:
|
||||||
000
|
0000
|
||||||
|
No, non-final
|
||||||
|
0000
|
||||||
0+:
|
0+:
|
||||||
- Match, final:
|
Yes, final:
|
||||||
100
|
1001
|
||||||
- Match, non-final:
|
Yes, non-final:
|
||||||
011
|
0011
|
||||||
- Non-match, final:
|
No, final:
|
||||||
100
|
1000 (note: Don't include last token!)
|
||||||
- Non-match, non-final:
|
No, non-final:
|
||||||
010
|
0010
|
||||||
|
?:
|
||||||
|
Yes, final:
|
||||||
|
1000
|
||||||
|
Yes, non-final:
|
||||||
|
0100
|
||||||
|
No, final:
|
||||||
|
1000 (note: Don't include last token!)
|
||||||
|
No, non-final:
|
||||||
|
0010
|
||||||
|
|
||||||
Problem: If a quantifier is matching, we're adding a lot of open partials
|
Problem: If a quantifier is matching, we're adding a lot of open partials
|
||||||
Question: Is it worth doing a lookahead, to see if we add?
|
|
||||||
'''
|
'''
|
||||||
cached_match = <uint64_t>cache.get(state.state.key)
|
cached_match = <uint64_t>cache.get(state.pattern.key)
|
||||||
cdef char is_match
|
cdef char is_match
|
||||||
if cached_match == 0:
|
if cached_match == 0:
|
||||||
is_match = get_is_match(state, token, extra_attrs)
|
is_match = get_is_match(state, token, extra_attrs)
|
||||||
cached_match = is_match + 1
|
cached_match = is_match + 1
|
||||||
cache.set(state.state.key, <void*>cached_match)
|
cache.set(state.pattern.key, <void*>cached_match)
|
||||||
elif cached_match == 1:
|
elif cached_match == 1:
|
||||||
is_match = 0
|
is_match = 0
|
||||||
else:
|
else:
|
||||||
is_match = 1
|
is_match = 1
|
||||||
quantifier = get_quantifier(state, token)
|
quantifier = get_quantifier(state)
|
||||||
is_final = get_is_final(state, token)
|
is_final = get_is_final(state)
|
||||||
|
if quantifier == ZERO:
|
||||||
|
is_match = not is_match
|
||||||
|
quantifier = ONE
|
||||||
if quantifier == ONE:
|
if quantifier == ONE:
|
||||||
if not is_match:
|
if is_match and is_final:
|
||||||
return ActionC(is_match=0, keep_state=0, advance_state=0)
|
# Yes, final: 1000
|
||||||
elif is_final:
|
return ActionC(1, 0, 0, 0)
|
||||||
return ActionC(is_match=1, keep_state=0, advance_state=0)
|
elif is_match and not is_final:
|
||||||
else:
|
# Yes, non-final: 0100
|
||||||
return ActionC(is_match=0, keep_state=0, advance_state=1)
|
return ActionC(0, 1, 0, 0)
|
||||||
|
elif not is_match and is_final:
|
||||||
|
# No, final: 0000
|
||||||
|
return ActionC(0, 0, 0, 0)
|
||||||
|
else:
|
||||||
|
# No, non-final 0000
|
||||||
|
return ActionC(0, 0, 0, 0)
|
||||||
|
|
||||||
elif quantifier == ZERO_PLUS:
|
elif quantifier == ZERO_PLUS:
|
||||||
if is_final:
|
if is_match and is_final:
|
||||||
return ActionC(is_match=1, keep_state=0, advance_state=0)
|
# Yes, final: 1001
|
||||||
elif is_match:
|
return ActionC(1, 0, 0, 1)
|
||||||
return ActionC(is_match=0, keep_state=1, advance_state=1)
|
elif is_match and not is_final:
|
||||||
else:
|
# Yes, non-final: 0011
|
||||||
return ActionC(is_match=0, keep_state=1, advance_state=0)
|
return ActionC(0, 0, 1, 1)
|
||||||
|
elif not is_match and is_final:
|
||||||
|
# No, final 1000 (note: Don't include last token!)
|
||||||
|
return ActionC(2, 0, 0, 0)
|
||||||
|
else:
|
||||||
|
# No, non-final 0010
|
||||||
|
return ActionC(0, 0, 1, 0)
|
||||||
elif quantifier == ZERO_ONE:
|
elif quantifier == ZERO_ONE:
|
||||||
if is_final:
|
if is_match and is_final:
|
||||||
return ActionC(is_match=1, keep_state=0, advance_state=0)
|
# Yes, final: 1000
|
||||||
elif is_match:
|
return ActionC(1, 0, 0, 0)
|
||||||
if state.last_action.keep_state:
|
elif is_match and not is_final:
|
||||||
return ActionC(is_match=0, keep_state=0, advance_state=1)
|
# Yes, non-final: 0100
|
||||||
else:
|
return ActionC(0, 1, 0, 0)
|
||||||
return ActionC(is_match=0, keep_state=1, advance_state=1)
|
elif not is_match and is_final:
|
||||||
|
# No, final 1000 (note: Don't include last token!)
|
||||||
|
return ActionC(2, 0, 0, 0)
|
||||||
|
else:
|
||||||
|
# No, non-final 0010
|
||||||
|
return ActionC(0, 0, 1, 0)
|
||||||
else:
|
else:
|
||||||
print(quantifier, is_match, is_final)
|
print(quantifier, is_match, is_final)
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
|
cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
|
||||||
spec = state.state
|
spec = state.pattern
|
||||||
for attr in spec.attrs[:spec.nr_attr]:
|
for attr in spec.attrs[:spec.nr_attr]:
|
||||||
if get_token_attr(token, attr.attr) != attr.value:
|
if get_token_attr(token, attr.attr) != attr.value:
|
||||||
return 0
|
return 0
|
||||||
|
@ -192,15 +245,15 @@ cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* e
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
cdef char get_is_final(PatternStateC state, const TokenC* token) nogil:
|
cdef char get_is_final(PatternStateC state) nogil:
|
||||||
if state.state[1].attrs[0].attr == ID and state.state[1].nr_attr == 0:
|
if state.pattern[1].attrs[0].attr == ID and state.pattern[1].nr_attr == 0:
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cdef char get_quantifier(PatternStateC state, const TokenC* token) nogil:
|
cdef char get_quantifier(PatternStateC state) nogil:
|
||||||
return state.state.quantifier
|
return state.pattern.quantifier
|
||||||
|
|
||||||
|
|
||||||
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
|
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
|
||||||
|
@ -232,7 +285,7 @@ cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
|
||||||
def _convert_strings(token_specs, string_store):
|
def _convert_strings(token_specs, string_store):
|
||||||
# Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
|
# Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
|
||||||
operators = {'*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
|
operators = {'*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
|
||||||
'?': (ZERO_ONE,), '1': (ONE,)}
|
'?': (ZERO_ONE,), '1': (ONE,), '!': (ZERO,)}
|
||||||
tokens = []
|
tokens = []
|
||||||
op = ONE
|
op = ONE
|
||||||
for spec in token_specs:
|
for spec in token_specs:
|
||||||
|
@ -392,6 +445,10 @@ cdef class Matcher:
|
||||||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||||
"""
|
"""
|
||||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doc)
|
matches = find_matches(&self.patterns[0], self.patterns.size(), doc)
|
||||||
|
for i, (key, start, end) in enumerate(matches):
|
||||||
|
on_match = self._callbacks.get(key, None)
|
||||||
|
if on_match is not None:
|
||||||
|
on_match(self, doc, i, matches)
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
def _normalize_key(self, key):
|
def _normalize_key(self, key):
|
||||||
|
|
Loading…
Reference in New Issue