mirror of https://github.com/explosion/spaCy.git
* Fix matching on extension attrs and predicates * Fix detection of match_id when using extension attributes. The match ID is stored as the last entry in the pattern. We were checking for this with nr_attr == 0, which didn't account for extension attributes. * Fix handling of predicates. The wrong count was being passed through, so even patterns that didn't have a predicate were being checked. * Fix regex pattern * Fix matcher set value test
This commit is contained in:
parent
f73d01aa32
commit
0d1ca15b13
|
@ -44,7 +44,7 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
|||
cdef Pool mem = Pool()
|
||||
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char))
|
||||
if extensions is not None and len(extensions) >= 1:
|
||||
nr_extra_attr = max(extensions.values())
|
||||
nr_extra_attr = max(extensions.values()) + 1
|
||||
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
|
||||
else:
|
||||
nr_extra_attr = 0
|
||||
|
@ -60,9 +60,8 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
|||
for i in range(doc.length):
|
||||
for j in range(n):
|
||||
states.push_back(PatternStateC(patterns[j], i, 0))
|
||||
transition_states(states, matches, predicate_cache,
|
||||
transition_states(states, matches, &predicate_cache[i],
|
||||
doc[i], extra_attr_values, predicates)
|
||||
predicate_cache += nr_predicate
|
||||
extra_attr_values += nr_extra_attr
|
||||
# Handle matches that end in 0-width patterns
|
||||
finish_states(matches, states)
|
||||
|
@ -74,6 +73,7 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
|||
matches[i].start,
|
||||
matches[i].start+matches[i].length
|
||||
)
|
||||
|
||||
# We need to deduplicate, because we could otherwise arrive at the same
|
||||
# match through two paths, e.g. .?.? matching 'a'. Are we matching the
|
||||
# first .?, or the second .? -- it doesn't matter, it's just one match.
|
||||
|
@ -89,7 +89,8 @@ cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
|
|||
# showed this wasn't the case when we had a reject-and-continue before a
|
||||
# match. I still don't really understand what's going on here, but this
|
||||
# workaround does resolve the issue.
|
||||
while pattern.attrs.attr != ID and pattern.nr_attr > 0:
|
||||
while pattern.attrs.attr != ID and \
|
||||
(pattern.nr_attr > 0 or pattern.nr_extra_attr > 0 or pattern.nr_py > 0):
|
||||
pattern += 1
|
||||
return pattern.attrs.value
|
||||
|
||||
|
@ -101,13 +102,17 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
cdef vector[PatternStateC] new_states
|
||||
cdef int nr_predicate = len(py_predicates)
|
||||
for i in range(states.size()):
|
||||
if states[i].pattern.nr_py != 0:
|
||||
if states[i].pattern.nr_py >= 1:
|
||||
update_predicate_cache(cached_py_predicates,
|
||||
states[i].pattern, token, py_predicates)
|
||||
for i in range(states.size()):
|
||||
action = get_action(states[i], token.c, extra_attrs,
|
||||
cached_py_predicates, nr_predicate)
|
||||
cached_py_predicates)
|
||||
if action == REJECT:
|
||||
continue
|
||||
# Keep only a subset of states (the active ones). Index q is the
|
||||
# states which are still alive. If we reject a state, we overwrite
|
||||
# it in the states list, because q doesn't advance.
|
||||
state = states[i]
|
||||
states[q] = state
|
||||
while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND):
|
||||
|
@ -126,7 +131,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
update_predicate_cache(cached_py_predicates,
|
||||
states[q].pattern, token, py_predicates)
|
||||
action = get_action(states[q], token.c, extra_attrs,
|
||||
cached_py_predicates, nr_predicate)
|
||||
cached_py_predicates)
|
||||
if action == REJECT:
|
||||
pass
|
||||
elif action == ADVANCE:
|
||||
|
@ -154,8 +159,8 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
states.push_back(new_states[i])
|
||||
|
||||
|
||||
cdef void update_predicate_cache(char* cache,
|
||||
const TokenPatternC* pattern, Token token, predicates):
|
||||
cdef int update_predicate_cache(char* cache,
|
||||
const TokenPatternC* pattern, Token token, predicates) except -1:
|
||||
# If the state references any extra predicates, check whether they match.
|
||||
# These are cached, so that we don't call these potentially expensive
|
||||
# Python functions more than we need to.
|
||||
|
@ -192,7 +197,7 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states)
|
|||
|
||||
cdef action_t get_action(PatternStateC state,
|
||||
const TokenC* token, const attr_t* extra_attrs,
|
||||
const char* predicate_matches, int nr_predicate) nogil:
|
||||
const char* predicate_matches) nogil:
|
||||
'''We need to consider:
|
||||
|
||||
a) Does the token match the specification? [Yes, No]
|
||||
|
@ -252,7 +257,7 @@ cdef action_t get_action(PatternStateC state,
|
|||
Problem: If a quantifier is matching, we're adding a lot of open partials
|
||||
'''
|
||||
cdef char is_match
|
||||
is_match = get_is_match(state, token, extra_attrs, predicate_matches, nr_predicate)
|
||||
is_match = get_is_match(state, token, extra_attrs, predicate_matches)
|
||||
quantifier = get_quantifier(state)
|
||||
is_final = get_is_final(state)
|
||||
if quantifier == ZERO:
|
||||
|
@ -303,9 +308,9 @@ cdef action_t get_action(PatternStateC state,
|
|||
|
||||
cdef char get_is_match(PatternStateC state,
|
||||
const TokenC* token, const attr_t* extra_attrs,
|
||||
const char* predicate_matches, int nr_predicate) nogil:
|
||||
for i in range(nr_predicate):
|
||||
if predicate_matches[i] == -1:
|
||||
const char* predicate_matches) nogil:
|
||||
for i in range(state.pattern.nr_py):
|
||||
if predicate_matches[state.pattern.py_predicates[i]] == -1:
|
||||
return 0
|
||||
spec = state.pattern
|
||||
for attr in spec.attrs[:spec.nr_attr]:
|
||||
|
@ -333,7 +338,7 @@ DEF PADDING = 5
|
|||
|
||||
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
|
||||
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
||||
cdef int i
|
||||
cdef int i, index
|
||||
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
|
||||
pattern[i].quantifier = quantifier
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
|
||||
|
@ -356,11 +361,13 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
|
|||
pattern[i].attrs[0].attr = ID
|
||||
pattern[i].attrs[0].value = entity_id
|
||||
pattern[i].nr_attr = 0
|
||||
pattern[i].nr_extra_attr = 0
|
||||
pattern[i].nr_py = 0
|
||||
return pattern
|
||||
|
||||
|
||||
cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
|
||||
while pattern.nr_attr != 0:
|
||||
while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0:
|
||||
pattern += 1
|
||||
id_attr = pattern[0].attrs[0]
|
||||
if id_attr.attr != ID:
|
||||
|
@ -384,7 +391,6 @@ def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predi
|
|||
extra_predicates.
|
||||
"""
|
||||
tokens = []
|
||||
seen_predicates = {}
|
||||
for spec in token_specs:
|
||||
if not spec:
|
||||
# Signifier for 'any token'
|
||||
|
@ -393,7 +399,7 @@ def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predi
|
|||
ops = _get_operators(spec)
|
||||
attr_values = _get_attr_values(spec, string_store)
|
||||
extensions = _get_extensions(spec, string_store, extensions_table)
|
||||
predicates = _get_extra_predicates(spec, extra_predicates, seen_predicates)
|
||||
predicates = _get_extra_predicates(spec, extra_predicates)
|
||||
for op in ops:
|
||||
tokens.append((op, list(attr_values), list(extensions), list(predicates)))
|
||||
return tokens
|
||||
|
@ -430,6 +436,7 @@ class _RegexPredicate(object):
|
|||
self.value = re.compile(value)
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
||||
assert self.predicate == 'REGEX'
|
||||
|
||||
def __call__(self, Token token):
|
||||
|
@ -447,6 +454,7 @@ class _SetMemberPredicate(object):
|
|||
self.value = set(get_string_id(v) for v in value)
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
||||
assert self.predicate in ('IN', 'NOT_IN')
|
||||
|
||||
def __call__(self, Token token):
|
||||
|
@ -459,6 +467,9 @@ class _SetMemberPredicate(object):
|
|||
else:
|
||||
return value not in self.value
|
||||
|
||||
def __repr__(self):
|
||||
return repr(('SetMemberPredicate', self.i, self.attr, self.value, self.predicate))
|
||||
|
||||
|
||||
class _ComparisonPredicate(object):
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False):
|
||||
|
@ -467,6 +478,7 @@ class _ComparisonPredicate(object):
|
|||
self.value = value
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
||||
assert self.predicate in ('==', '!=', '>=', '<=', '>', '<')
|
||||
|
||||
def __call__(self, Token token):
|
||||
|
@ -488,7 +500,7 @@ class _ComparisonPredicate(object):
|
|||
return value < self.value
|
||||
|
||||
|
||||
def _get_extra_predicates(spec, extra_predicates, seen_predicates):
|
||||
def _get_extra_predicates(spec, extra_predicates):
|
||||
predicate_types = {
|
||||
'REGEX': _RegexPredicate,
|
||||
'IN': _SetMemberPredicate,
|
||||
|
@ -499,6 +511,7 @@ def _get_extra_predicates(spec, extra_predicates, seen_predicates):
|
|||
'>': _ComparisonPredicate,
|
||||
'<': _ComparisonPredicate,
|
||||
}
|
||||
seen_predicates = {pred.key: pred.i for pred in extra_predicates}
|
||||
output = []
|
||||
for attr, value in spec.items():
|
||||
if isinstance(attr, basestring):
|
||||
|
@ -516,16 +529,15 @@ def _get_extra_predicates(spec, extra_predicates, seen_predicates):
|
|||
if isinstance(value, dict):
|
||||
for type_, cls in predicate_types.items():
|
||||
if type_ in value:
|
||||
key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True))
|
||||
predicate = cls(len(extra_predicates), attr, value[type_], type_)
|
||||
# Don't create a redundant predicates.
|
||||
# This helps with efficiency, as we're caching the results.
|
||||
if key in seen_predicates:
|
||||
output.append(seen_predicates[key])
|
||||
if predicate.key in seen_predicates:
|
||||
output.append(seen_predicates[predicate.key])
|
||||
else:
|
||||
predicate = cls(len(extra_predicates), attr, value[type_], type_)
|
||||
extra_predicates.append(predicate)
|
||||
output.append(predicate.i)
|
||||
seen_predicates[key] = predicate.i
|
||||
seen_predicates[predicate.key] = predicate.i
|
||||
return output
|
||||
|
||||
|
||||
|
|
|
@ -207,14 +207,13 @@ def test_matcher_set_value(en_vocab):
|
|||
assert len(matches) == 0
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_matcher_set_value_operator(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{"ORTH": {"IN": ["a", "the"]}, "OP": "?"}, {"ORTH": "house"}]
|
||||
matcher.add("DET_HOUSE", None, pattern)
|
||||
doc = Doc(en_vocab, words=["In", "a", "house"])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 1
|
||||
assert len(matches) == 2
|
||||
doc = Doc(en_vocab, words=["my", "house"])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 1
|
||||
|
|
|
@ -6,7 +6,6 @@ from spacy.matcher import Matcher
|
|||
from spacy.tokens import Token, Doc
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_issue1971(en_vocab):
|
||||
# Possibly related to #2675 and #2671?
|
||||
matcher = Matcher(en_vocab)
|
||||
|
@ -22,21 +21,20 @@ def test_issue1971(en_vocab):
|
|||
# We could also assert length 1 here, but this is more conclusive, because
|
||||
# the real problem here is that it returns a duplicate match for a match_id
|
||||
# that's not actually in the vocab!
|
||||
assert all(match_id in en_vocab.strings for match_id, start, end in matcher(doc))
|
||||
matches = matcher(doc)
|
||||
assert all([match_id in en_vocab.strings for match_id, start, end in matches])
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_issue_1971_2(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern1 = [{"LOWER": {"IN": ["eur"]}}, {"LIKE_NUM": True}]
|
||||
pattern2 = list(reversed(pattern1))
|
||||
pattern1 = [{"ORTH": "EUR", "LOWER": {"IN": ["eur"]}}, {"LIKE_NUM": True}]
|
||||
pattern2 = [{"LIKE_NUM": True}, {"ORTH": "EUR"}] #{"IN": ["EUR"]}}]
|
||||
doc = Doc(en_vocab, words=["EUR", "10", "is", "10", "EUR"])
|
||||
matcher.add("TEST", None, pattern1, pattern2)
|
||||
matcher.add("TEST1", None, pattern1, pattern2)
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 2
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_issue_1971_3(en_vocab):
|
||||
"""Test that pattern matches correctly for multiple extension attributes."""
|
||||
Token.set_extension("a", default=1)
|
||||
|
@ -50,7 +48,6 @@ def test_issue_1971_3(en_vocab):
|
|||
assert matches == sorted([("A", 0, 1), ("A", 1, 2), ("B", 0, 1), ("B", 1, 2)])
|
||||
|
||||
|
||||
# @pytest.mark.xfail
|
||||
def test_issue_1971_4(en_vocab):
|
||||
"""Test that pattern matches correctly with multiple extension attribute
|
||||
values on a single token.
|
||||
|
|
Loading…
Reference in New Issue