diff --git a/spacy/matcher2.pyx b/spacy/matcher2.pyx
new file mode 100644
index 000000000..ff90e644d
--- /dev/null
+++ b/spacy/matcher2.pyx
@@ -0,0 +1,399 @@
+# cython: infer_types=True
+from libcpp.vector cimport vector
+from libc.stdint cimport int32_t, uint64_t
+from preshed.maps cimport PreshMap
+from cymem.cymem cimport Pool
+from murmurhash.mrmr cimport hash64
+from .typedefs cimport attr_t, hash_t
+from .structs cimport TokenC
+from .lexeme cimport attr_id_t
+from .vocab cimport Vocab
+from .tokens.doc cimport Doc
+from .tokens.doc cimport get_token_attr
+from .attrs cimport ID, attr_id_t, NULL_ATTR
+from .attrs import IDS
+
+
+cdef enum quantifier_t:
+    ZERO
+    ZERO_ONE
+    ZERO_PLUS
+    ONE
+    ONE_PLUS
+
+
+cdef struct AttrValueC:
+    attr_id_t attr
+    attr_t value
+
+
+cdef struct TokenPatternC:
+    AttrValueC* attrs
+    int32_t nr_attr
+    quantifier_t quantifier
+    hash_t key
+
+
+cdef struct ActionC:
+    char is_match
+    char keep_state
+    char advance_state
+
+
+cdef struct PatternStateC:
+    TokenPatternC* state
+    int32_t pattern_id
+    int32_t start
+    ActionC last_action
+
+
+cdef struct MatchC:
+    int32_t pattern_id
+    int32_t start
+    int32_t end
+
+
+cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
+    cdef vector[PatternStateC] init_states
+    cdef ActionC null_action = ActionC(-1, -1, -1)
+    for i in range(n):
+        init_states.push_back(PatternStateC(patterns[i], i, -1, last_action=null_action))
+    cdef vector[PatternStateC] curr_states
+    cdef vector[PatternStateC] nexts
+    cdef vector[MatchC] matches
+    cdef PreshMap cache
+    cdef Pool mem = Pool()
+    # TODO: Prefill this with the extra attribute values.
+    extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
+    for i in range(doc.length):
+        nexts.clear()
+        # Cached match results are only valid for the current token, so the
+        # cache must be reset on each iteration.
+        cache = PreshMap()
+        for j in range(curr_states.size()):
+            action = get_action(curr_states[j], &doc.c[i], extra_attrs[i], cache)
+            transition(matches, nexts,
+                action, curr_states[j], i)
+        for j in range(init_states.size()):
+            action = get_action(init_states[j], &doc.c[i], extra_attrs[i], cache)
+            transition(matches, nexts,
+                action, init_states[j], i)
+        nexts, curr_states = curr_states, nexts
+    # Filter out matches that have a longer equivalent.
+    longest_matches = {}
+    for i in range(matches.size()):
+        key = matches[i].pattern_id, matches[i].start
+        length = matches[i].end - matches[i].start
+        if key not in longest_matches or length > longest_matches[key]:
+            longest_matches[key] = length
+    return [(pattern_id, start, start + length)
+            for (pattern_id, start), length in longest_matches.items()]
+
+
+cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts,
+                     ActionC action, PatternStateC state, int token) except *:
+    if state.start == -1:
+        state.start = token
+    if action.is_match:
+        matches.push_back(
+            MatchC(pattern_id=state.pattern_id, start=state.start, end=token+1))
+    if action.keep_state:
+        nexts.push_back(PatternStateC(pattern_id=state.pattern_id,
+            start=state.start, state=state.state, last_action=action))
+    if action.advance_state:
+        nexts.push_back(PatternStateC(pattern_id=state.pattern_id,
+            start=state.start, state=state.state+1, last_action=action))
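+
+
+# Worked example (illustrative only): for the single pattern
+#     [{'LOWER': 'hello'}, {'IS_PUNCT': True, 'OP': '?'}, {'LOWER': 'world'}]
+# run over the text "hello, world", find_matches proceeds as follows:
+#
+#   token 0 "hello": the init state matches spec 0 (ONE, non-final), so
+#       get_action returns 001 and transition() pushes an advanced copy
+#       pointing at spec 1, with start=0.
+#   token 1 ",":     spec 1 (ZERO_ONE, non-final) matches, so get_action
+#       returns 011 and copies pointing at spec 1 and spec 2 are pushed.
+#   token 2 "world": the copy at spec 2 (ONE, final) matches, so get_action
+#       returns 100 and MatchC(pattern_id=0, start=0, end=3) is recorded;
+#       the copy still at spec 1 fails to match and is dropped.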
+
+
+cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs,
+                        PreshMap cache) except *:
+    '''We need to consider:
+
+    a) Does the token match the specification? [Yes, No]
+    b) What's the quantifier? [1, 0+, ?]
+    c) Is this the last specification? [final, non-final]
+
+    We therefore have 12 cases to consider. For each case, we need to know
+    whether to emit a match, whether to keep the current state in the
+    partials, and whether to add an advanced state to the partials.
+
+    We therefore have eight possible results for these three booleans, which
+    we'll code as 000, 001, etc.
+
+    1:
+      - Match, final:
+          100
+      - Match, non-final:
+          001
+      - No match:
+          000
+
+    0+:
+      - Match, final:
+          100
+      - Match, non-final:
+          011
+      - Non-match, final:
+          100
+      - Non-match, non-final:
+          010
+
+    ?:
+      - Final (match or non-match):
+          100
+      - Match, non-final:
+          011 (001 if the previous action already kept this state)
+      - Non-match, non-final:
+          000
+
+    Problem: If a quantifier is matching, we're adding a lot of open partials.
+    Question: Is it worth doing a lookahead, to see if we add?
+    '''
+    cdef char is_match
+    cached_match = <size_t>cache.get(state.state.key)
+    if cached_match == 0:
+        is_match = get_is_match(state, token, extra_attrs)
+        cached_match = is_match + 1
+        cache.set(state.state.key, <void*>cached_match)
+    elif cached_match == 1:
+        is_match = 0
+    else:
+        is_match = 1
+    quantifier = get_quantifier(state, token)
+    is_final = get_is_final(state, token)
+    if quantifier == ONE:
+        if not is_match:
+            return ActionC(is_match=0, keep_state=0, advance_state=0)
+        elif is_final:
+            return ActionC(is_match=1, keep_state=0, advance_state=0)
+        else:
+            return ActionC(is_match=0, keep_state=0, advance_state=1)
+    elif quantifier == ZERO_PLUS:
+        if is_final:
+            return ActionC(is_match=1, keep_state=0, advance_state=0)
+        elif is_match:
+            return ActionC(is_match=0, keep_state=1, advance_state=1)
+        else:
+            return ActionC(is_match=0, keep_state=1, advance_state=0)
+    elif quantifier == ZERO_ONE:
+        if is_final:
+            return ActionC(is_match=1, keep_state=0, advance_state=0)
+        elif is_match:
+            if state.last_action.keep_state:
+                return ActionC(is_match=0, keep_state=0, advance_state=1)
+            else:
+                return ActionC(is_match=0, keep_state=1, advance_state=1)
+        else:
+            # The optional token is absent, and this machine has no epsilon
+            # transitions to try the next spec against the same token, so
+            # the state is dropped.
+            return ActionC(is_match=0, keep_state=0, advance_state=0)
+    else:
+        raise ValueError("Unexpected quantifier %d (is_match=%d, is_final=%d)"
+                         % (quantifier, is_match, is_final))
+
+
+cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
+    spec = state.state
+    for attr in spec.attrs[:spec.nr_attr]:
+        if get_token_attr(token, attr.attr) != attr.value:
+            return 0
+    return 1
+
+
+cdef char get_is_final(PatternStateC state, const TokenC* token) nogil:
+    if state.state[1].attrs[0].attr == ID and state.state[1].nr_attr == 0:
+        return 1
+    else:
+        return 0
+
+
+cdef char get_quantifier(PatternStateC state, const TokenC* token) nogil:
+    return state.state.quantifier
+
+
+cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
+                                 object token_specs) except NULL:
+    pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
+    cdef int i
+    for i, (quantifier, spec) in enumerate(token_specs):
+        pattern[i].quantifier = quantifier
+        pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
+        pattern[i].nr_attr = len(spec)
+        for j, (attr, value) in enumerate(spec):
+            pattern[i].attrs[j].attr = attr
+            pattern[i].attrs[j].value = value
+        pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0)
+    # Terminal node: nr_attr == 0 marks the end of the pattern, and the
+    # entity ID is stashed in attrs[0] so it can be recovered later.
+    i = len(token_specs)
+    pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
+    pattern[i].attrs[0].attr = ID
+    pattern[i].attrs[0].value = entity_id
+    pattern[i].nr_attr = 0
+    return pattern
+
+
+cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
+    while pattern.nr_attr != 0:
+        pattern += 1
+    id_attr = pattern[0].attrs[0]
+    return id_attr.value
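+
+
+# Illustrative layout sketch: a single-token pattern compiled by
+# init_pattern(mem, entity_id, [(ONE, [(LOWER, lower_hash)])]) occupies a
+# two-node array (lower_hash being the StringStore hash of the word):
+#
+#     pattern[0]: quantifier=ONE, nr_attr=1, attrs[0]=(LOWER, lower_hash)
+#     pattern[1]: nr_attr=0, attrs[0]=(ID, entity_id)   (terminal sentinel)
+#
+# get_is_final() recognises the sentinel by nr_attr == 0, and
+# get_pattern_key() walks forward to it to recover the entity ID.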
+
+
+def _convert_strings(token_specs, string_store):
+    # Support 'syntactic sugar' operator '+', as a combination of ONE, ZERO_PLUS
+    operators = {'*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
+                 '?': (ZERO_ONE,), '1': (ONE,)}
+    tokens = []
+    op = ONE
+    for spec in token_specs:
+        if not spec:
+            # Signifier for 'any token'
+            tokens.append((ONE, [(NULL_ATTR, 0)]))
+            continue
+        token = []
+        ops = (ONE,)
+        for attr, value in spec.items():
+            if isinstance(attr, basestring) and attr.upper() == 'OP':
+                if value in operators:
+                    ops = operators[value]
+                else:
+                    msg = "Unknown operator '%s'. Options: %s"
+                    raise KeyError(msg % (value, ', '.join(operators.keys())))
+            if isinstance(attr, basestring):
+                attr = IDS.get(attr.upper())
+            if isinstance(value, basestring):
+                value = string_store.add(value)
+            if isinstance(value, bool):
+                value = int(value)
+            if attr is not None:
+                token.append((attr, value))
+        for op in ops:
+            tokens.append((op, token))
+    return tokens
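+
+
+# For example (illustrative; the ORTH value is stored as a StringStore hash):
+#     _convert_strings([{'ORTH': 'hello', 'OP': '+'}], strings)
+# expands the '+' operator into ONE followed by ZERO_PLUS, with both specs
+# sharing one attribute list:
+#     [(ONE, [(ORTH, hello_hash)]), (ZERO_PLUS, [(ORTH, hello_hash)])]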
+ """ + for pattern in patterns: + if len(pattern) == 0: + msg = ("Cannot add pattern for zero tokens to matcher.\n" + "key: {key}\n") + raise ValueError(msg.format(key=key)) + key = self._normalize_key(key) + for pattern in patterns: + specs = _convert_strings(pattern, self.vocab.strings) + self.patterns.push_back(init_pattern(self.mem, key, specs)) + self._patterns.setdefault(key, []) + self._callbacks[key] = on_match + self._patterns[key].extend(patterns) + + def remove(self, key): + """Remove a rule from the matcher. A KeyError is raised if the key does + not exist. + + key (unicode): The ID of the match rule. + """ + key = self._normalize_key(key) + self._patterns.pop(key) + self._callbacks.pop(key) + cdef int i = 0 + while i < self.patterns.size(): + pattern_key = get_pattern_key(self.patterns.at(i)) + if pattern_key == key: + self.patterns.erase(self.patterns.begin()+i) + else: + i += 1 + + def has_key(self, key): + """Check whether the matcher has a rule with a given key. + + key (string or int): The key to check. + RETURNS (bool): Whether the matcher has the rule. + """ + key = self._normalize_key(key) + return key in self._patterns + + def get(self, key, default=None): + """Retrieve the pattern stored for a key. + + key (unicode or int): The key to retrieve. + RETURNS (tuple): The rule, as an (on_match, patterns) tuple. + """ + key = self._normalize_key(key) + if key not in self._patterns: + return default + return (self._callbacks[key], self._patterns[key]) + + def __call__(self, Doc doc): + """Find all token sequences matching the supplied pattern. + + doc (Doc): The document to match over. + RETURNS (list): A list of `(key, start, end)` tuples, + describing the matches. A match tuple describes a span + `doc[start:end]`. The `label_id` and `key` are both integers. + """ + matches = find_matches(&self.patterns[0], self.patterns.size(), doc) + return matches + + +def unpickle_matcher(vocab, patterns, callbacks): + matcher = Matcher(vocab) + for key, specs in patterns.items(): + callback = callbacks.get(key, None) + matcher.add(key, callback, *specs) + return matcher + + +