From 0d1ca15b1351041469356c74db6e2727fc934836 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 20 Feb 2019 21:30:39 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AB=20Fix=20bugs=20in=20matcher=20exte?= =?UTF-8?q?nsions.=20Closes=20#1971=20(#3301)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix matching on extension attrs and predicates * Fix detection of match_id when using extension attributes. The match ID is stored as the last entry in the pattern. We were checking for this with nr_attr == 0, which didn't account for extension attributes. * Fix handling of predicates. The wrong count was being passed through, so even patterns that didn't have a predicate were being checked. * Fix regex pattern * Fix matcher set value test --- spacy/matcher/matcher.pyx | 60 ++++++++++++++---------- spacy/tests/matcher/test_matcher_api.py | 3 +- spacy/tests/regression/test_issue1971.py | 13 ++--- 3 files changed, 42 insertions(+), 34 deletions(-) diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index f0783df3f..5f17ec867 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -44,7 +44,7 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None, cdef Pool mem = Pool() predicate_cache = mem.alloc(doc.length * len(predicates), sizeof(char)) if extensions is not None and len(extensions) >= 1: - nr_extra_attr = max(extensions.values()) + nr_extra_attr = max(extensions.values()) + 1 extra_attr_values = mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t)) else: nr_extra_attr = 0 @@ -60,9 +60,8 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None, for i in range(doc.length): for j in range(n): states.push_back(PatternStateC(patterns[j], i, 0)) - transition_states(states, matches, predicate_cache, + transition_states(states, matches, &predicate_cache[i], doc[i], extra_attr_values, predicates) - predicate_cache += nr_predicate extra_attr_values += nr_extra_attr # Handle matches that end in 0-width patterns finish_states(matches, states) @@ -74,6 +73,7 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None, matches[i].start, matches[i].start+matches[i].length ) + # We need to deduplicate, because we could otherwise arrive at the same # match through two paths, e.g. .?.? matching 'a'. Are we matching the # first .?, or the second .? -- it doesn't matter, it's just one match. @@ -89,7 +89,8 @@ cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil: # showed this wasn't the case when we had a reject-and-continue before a # match. I still don't really understand what's going on here, but this # workaround does resolve the issue. - while pattern.attrs.attr != ID and pattern.nr_attr > 0: + while pattern.attrs.attr != ID and \ + (pattern.nr_attr > 0 or pattern.nr_extra_attr > 0 or pattern.nr_py > 0): pattern += 1 return pattern.attrs.value @@ -101,13 +102,17 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match cdef vector[PatternStateC] new_states cdef int nr_predicate = len(py_predicates) for i in range(states.size()): - if states[i].pattern.nr_py != 0: + if states[i].pattern.nr_py >= 1: update_predicate_cache(cached_py_predicates, states[i].pattern, token, py_predicates) + for i in range(states.size()): action = get_action(states[i], token.c, extra_attrs, - cached_py_predicates, nr_predicate) + cached_py_predicates) if action == REJECT: continue + # Keep only a subset of states (the active ones). Index q is the + # states which are still alive. If we reject a state, we overwrite + # it in the states list, because q doesn't advance. state = states[i] states[q] = state while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND): @@ -126,7 +131,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match update_predicate_cache(cached_py_predicates, states[q].pattern, token, py_predicates) action = get_action(states[q], token.c, extra_attrs, - cached_py_predicates, nr_predicate) + cached_py_predicates) if action == REJECT: pass elif action == ADVANCE: @@ -154,8 +159,8 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match states.push_back(new_states[i]) -cdef void update_predicate_cache(char* cache, - const TokenPatternC* pattern, Token token, predicates): +cdef int update_predicate_cache(char* cache, + const TokenPatternC* pattern, Token token, predicates) except -1: # If the state references any extra predicates, check whether they match. # These are cached, so that we don't call these potentially expensive # Python functions more than we need to. @@ -192,7 +197,7 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) cdef action_t get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs, - const char* predicate_matches, int nr_predicate) nogil: + const char* predicate_matches) nogil: '''We need to consider: a) Does the token match the specification? [Yes, No] @@ -252,7 +257,7 @@ cdef action_t get_action(PatternStateC state, Problem: If a quantifier is matching, we're adding a lot of open partials ''' cdef char is_match - is_match = get_is_match(state, token, extra_attrs, predicate_matches, nr_predicate) + is_match = get_is_match(state, token, extra_attrs, predicate_matches) quantifier = get_quantifier(state) is_final = get_is_final(state) if quantifier == ZERO: @@ -303,9 +308,9 @@ cdef action_t get_action(PatternStateC state, cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs, - const char* predicate_matches, int nr_predicate) nogil: - for i in range(nr_predicate): - if predicate_matches[i] == -1: + const char* predicate_matches) nogil: + for i in range(state.pattern.nr_py): + if predicate_matches[state.pattern.py_predicates[i]] == -1: return 0 spec = state.pattern for attr in spec.attrs[:spec.nr_attr]: @@ -333,7 +338,7 @@ DEF PADDING = 5 cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL: pattern = mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC)) - cdef int i + cdef int i, index for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs): pattern[i].quantifier = quantifier pattern[i].attrs = mem.alloc(len(spec), sizeof(AttrValueC)) @@ -356,11 +361,13 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) pattern[i].attrs[0].attr = ID pattern[i].attrs[0].value = entity_id pattern[i].nr_attr = 0 + pattern[i].nr_extra_attr = 0 + pattern[i].nr_py = 0 return pattern cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil: - while pattern.nr_attr != 0: + while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0: pattern += 1 id_attr = pattern[0].attrs[0] if id_attr.attr != ID: @@ -384,7 +391,6 @@ def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predi extra_predicates. """ tokens = [] - seen_predicates = {} for spec in token_specs: if not spec: # Signifier for 'any token' @@ -393,7 +399,7 @@ def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predi ops = _get_operators(spec) attr_values = _get_attr_values(spec, string_store) extensions = _get_extensions(spec, string_store, extensions_table) - predicates = _get_extra_predicates(spec, extra_predicates, seen_predicates) + predicates = _get_extra_predicates(spec, extra_predicates) for op in ops: tokens.append((op, list(attr_values), list(extensions), list(predicates))) return tokens @@ -430,6 +436,7 @@ class _RegexPredicate(object): self.value = re.compile(value) self.predicate = predicate self.is_extension = is_extension + self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) assert self.predicate == 'REGEX' def __call__(self, Token token): @@ -447,6 +454,7 @@ class _SetMemberPredicate(object): self.value = set(get_string_id(v) for v in value) self.predicate = predicate self.is_extension = is_extension + self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) assert self.predicate in ('IN', 'NOT_IN') def __call__(self, Token token): @@ -459,6 +467,9 @@ class _SetMemberPredicate(object): else: return value not in self.value + def __repr__(self): + return repr(('SetMemberPredicate', self.i, self.attr, self.value, self.predicate)) + class _ComparisonPredicate(object): def __init__(self, i, attr, value, predicate, is_extension=False): @@ -467,6 +478,7 @@ class _ComparisonPredicate(object): self.value = value self.predicate = predicate self.is_extension = is_extension + self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) assert self.predicate in ('==', '!=', '>=', '<=', '>', '<') def __call__(self, Token token): @@ -488,7 +500,7 @@ class _ComparisonPredicate(object): return value < self.value -def _get_extra_predicates(spec, extra_predicates, seen_predicates): +def _get_extra_predicates(spec, extra_predicates): predicate_types = { 'REGEX': _RegexPredicate, 'IN': _SetMemberPredicate, @@ -499,6 +511,7 @@ def _get_extra_predicates(spec, extra_predicates, seen_predicates): '>': _ComparisonPredicate, '<': _ComparisonPredicate, } + seen_predicates = {pred.key: pred.i for pred in extra_predicates} output = [] for attr, value in spec.items(): if isinstance(attr, basestring): @@ -516,16 +529,15 @@ def _get_extra_predicates(spec, extra_predicates, seen_predicates): if isinstance(value, dict): for type_, cls in predicate_types.items(): if type_ in value: - key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True)) + predicate = cls(len(extra_predicates), attr, value[type_], type_) # Don't create a redundant predicates. # This helps with efficiency, as we're caching the results. - if key in seen_predicates: - output.append(seen_predicates[key]) + if predicate.key in seen_predicates: + output.append(seen_predicates[predicate.key]) else: - predicate = cls(len(extra_predicates), attr, value[type_], type_) extra_predicates.append(predicate) output.append(predicate.i) - seen_predicates[key] = predicate.i + seen_predicates[predicate.key] = predicate.i return output diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py index 7f7ebfc73..6ece07482 100644 --- a/spacy/tests/matcher/test_matcher_api.py +++ b/spacy/tests/matcher/test_matcher_api.py @@ -207,14 +207,13 @@ def test_matcher_set_value(en_vocab): assert len(matches) == 0 -@pytest.mark.xfail def test_matcher_set_value_operator(en_vocab): matcher = Matcher(en_vocab) pattern = [{"ORTH": {"IN": ["a", "the"]}, "OP": "?"}, {"ORTH": "house"}] matcher.add("DET_HOUSE", None, pattern) doc = Doc(en_vocab, words=["In", "a", "house"]) matches = matcher(doc) - assert len(matches) == 1 + assert len(matches) == 2 doc = Doc(en_vocab, words=["my", "house"]) matches = matcher(doc) assert len(matches) == 1 diff --git a/spacy/tests/regression/test_issue1971.py b/spacy/tests/regression/test_issue1971.py index ecc7ebda1..e7273a5b0 100644 --- a/spacy/tests/regression/test_issue1971.py +++ b/spacy/tests/regression/test_issue1971.py @@ -6,7 +6,6 @@ from spacy.matcher import Matcher from spacy.tokens import Token, Doc -@pytest.mark.xfail def test_issue1971(en_vocab): # Possibly related to #2675 and #2671? matcher = Matcher(en_vocab) @@ -22,21 +21,20 @@ def test_issue1971(en_vocab): # We could also assert length 1 here, but this is more conclusive, because # the real problem here is that it returns a duplicate match for a match_id # that's not actually in the vocab! - assert all(match_id in en_vocab.strings for match_id, start, end in matcher(doc)) + matches = matcher(doc) + assert all([match_id in en_vocab.strings for match_id, start, end in matches]) -@pytest.mark.xfail def test_issue_1971_2(en_vocab): matcher = Matcher(en_vocab) - pattern1 = [{"LOWER": {"IN": ["eur"]}}, {"LIKE_NUM": True}] - pattern2 = list(reversed(pattern1)) + pattern1 = [{"ORTH": "EUR", "LOWER": {"IN": ["eur"]}}, {"LIKE_NUM": True}] + pattern2 = [{"LIKE_NUM": True}, {"ORTH": "EUR"}] #{"IN": ["EUR"]}}] doc = Doc(en_vocab, words=["EUR", "10", "is", "10", "EUR"]) - matcher.add("TEST", None, pattern1, pattern2) + matcher.add("TEST1", None, pattern1, pattern2) matches = matcher(doc) assert len(matches) == 2 -@pytest.mark.xfail def test_issue_1971_3(en_vocab): """Test that pattern matches correctly for multiple extension attributes.""" Token.set_extension("a", default=1) @@ -50,7 +48,6 @@ def test_issue_1971_3(en_vocab): assert matches == sorted([("A", 0, 1), ("A", 1, 2), ("B", 0, 1), ("B", 1, 2)]) -# @pytest.mark.xfail def test_issue_1971_4(en_vocab): """Test that pattern matches correctly with multiple extension attribute values on a single token.