From 1ea4df459dbd96842b337240059be5dd1ff44ef1 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 7 Feb 2019 09:42:25 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AB=20Break=20up=20large=20matcher.pyx?= =?UTF-8?q?=20(#3236)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Break up large matcher.pyx * Remove unused function --- setup.py | 4 +- spacy/matcher/__init__.py | 6 + spacy/matcher/dependencymatcher.pyx | 354 +++++++++++++++ spacy/matcher/matcher.pxd | 69 +++ spacy/{ => matcher}/matcher.pyx | 639 +--------------------------- spacy/matcher/phrasematcher.pyx | 210 +++++++++ 6 files changed, 658 insertions(+), 624 deletions(-) create mode 100644 spacy/matcher/__init__.py create mode 100644 spacy/matcher/dependencymatcher.pyx create mode 100644 spacy/matcher/matcher.pxd rename spacy/{ => matcher}/matcher.pyx (54%) create mode 100644 spacy/matcher/phrasematcher.pyx diff --git a/setup.py b/setup.py index 1367cb352..82191bcff 100755 --- a/setup.py +++ b/setup.py @@ -56,7 +56,9 @@ MOD_NAMES = [ "spacy.tokens.span", "spacy.tokens.token", "spacy.tokens._retokenize", - "spacy.matcher", + "spacy.matcher.matcher", + "spacy.matcher.phrasematcher", + "spacy.matcher.dependencymatcher", "spacy.syntax.ner", "spacy.symbols", "spacy.vectors", diff --git a/spacy/matcher/__init__.py b/spacy/matcher/__init__.py new file mode 100644 index 000000000..e417097b2 --- /dev/null +++ b/spacy/matcher/__init__.py @@ -0,0 +1,6 @@ +# coding: utf8 +from __future__ import unicode_literals + +from .matcher import Matcher +from .phrasematcher import PhraseMatcher +from .dependencymatcher import DependencyTreeMatcher diff --git a/spacy/matcher/dependencymatcher.pyx b/spacy/matcher/dependencymatcher.pyx new file mode 100644 index 000000000..1a4b23e11 --- /dev/null +++ b/spacy/matcher/dependencymatcher.pyx @@ -0,0 +1,354 @@ +# cython: infer_types=True +# cython: profile=True +from __future__ import unicode_literals + +from cymem.cymem cimport Pool +from preshed.maps cimport PreshMap + +from .matcher cimport Matcher +from ..vocab cimport Vocab +from ..tokens.doc cimport Doc + +from .matcher import unpickle_matcher +from ..errors import Errors + + +DELIMITER = '||' +INDEX_HEAD = 1 +INDEX_RELOP = 0 + + +cdef class DependencyTreeMatcher: + """Match dependency parse tree based on pattern rules.""" + cdef Pool mem + cdef readonly Vocab vocab + cdef readonly Matcher token_matcher + cdef public object _patterns + cdef public object _keys_to_token + cdef public object _root + cdef public object _entities + cdef public object _callbacks + cdef public object _nodes + cdef public object _tree + + def __init__(self, vocab): + """Create the DependencyTreeMatcher. + + vocab (Vocab): The vocabulary object, which must be shared with the + documents the matcher will operate on. + RETURNS (DependencyTreeMatcher): The newly constructed object. + """ + size = 20 + self.token_matcher = Matcher(vocab) + self._keys_to_token = {} + self._patterns = {} + self._root = {} + self._nodes = {} + self._tree = {} + self._entities = {} + self._callbacks = {} + self.vocab = vocab + self.mem = Pool() + + def __reduce__(self): + data = (self.vocab, self._patterns,self._tree, self._callbacks) + return (unpickle_matcher, data, None, None) + + def __len__(self): + """Get the number of rules, which are edges ,added to the dependency tree matcher. + + RETURNS (int): The number of rules. + """ + return len(self._patterns) + + def __contains__(self, key): + """Check whether the matcher contains rules for a match ID. + + key (unicode): The match ID. + RETURNS (bool): Whether the matcher contains rules for this match ID. + """ + return self._normalize_key(key) in self._patterns + + def validateInput(self, pattern, key): + idx = 0 + visitedNodes = {} + for relation in pattern: + if 'PATTERN' not in relation or 'SPEC' not in relation: + raise ValueError(Errors.E098.format(key=key)) + if idx == 0: + if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' not in relation['SPEC'] and 'NBOR_NAME' not in relation['SPEC']): + raise ValueError(Errors.E099.format(key=key)) + visitedNodes[relation['SPEC']['NODE_NAME']] = True + else: + if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' in relation['SPEC'] and 'NBOR_NAME' in relation['SPEC']): + raise ValueError(Errors.E100.format(key=key)) + if relation['SPEC']['NODE_NAME'] in visitedNodes or relation['SPEC']['NBOR_NAME'] not in visitedNodes: + raise ValueError(Errors.E101.format(key=key)) + visitedNodes[relation['SPEC']['NODE_NAME']] = True + visitedNodes[relation['SPEC']['NBOR_NAME']] = True + idx = idx + 1 + + def add(self, key, on_match, *patterns): + for pattern in patterns: + if len(pattern) == 0: + raise ValueError(Errors.E012.format(key=key)) + self.validateInput(pattern,key) + + key = self._normalize_key(key) + + _patterns = [] + for pattern in patterns: + token_patterns = [] + for i in range(len(pattern)): + token_pattern = [pattern[i]['PATTERN']] + token_patterns.append(token_pattern) + # self.patterns.append(token_patterns) + _patterns.append(token_patterns) + + self._patterns.setdefault(key, []) + self._callbacks[key] = on_match + self._patterns[key].extend(_patterns) + + # Add each node pattern of all the input patterns individually to the matcher. + # This enables only a single instance of Matcher to be used. + # Multiple adds are required to track each node pattern. + _keys_to_token_list = [] + for i in range(len(_patterns)): + _keys_to_token = {} + # TODO : Better ways to hash edges in pattern? + for j in range(len(_patterns[i])): + k = self._normalize_key(unicode(key)+DELIMITER+unicode(i)+DELIMITER+unicode(j)) + self.token_matcher.add(k,None,_patterns[i][j]) + _keys_to_token[k] = j + _keys_to_token_list.append(_keys_to_token) + + self._keys_to_token.setdefault(key, []) + self._keys_to_token[key].extend(_keys_to_token_list) + + _nodes_list = [] + for pattern in patterns: + nodes = {} + for i in range(len(pattern)): + nodes[pattern[i]['SPEC']['NODE_NAME']]=i + _nodes_list.append(nodes) + + self._nodes.setdefault(key, []) + self._nodes[key].extend(_nodes_list) + + # Create an object tree to traverse later on. + # This datastructure enable easy tree pattern match. + # Doc-Token based tree cannot be reused since it is memory heavy and + # tightly coupled with doc + self.retrieve_tree(patterns,_nodes_list,key) + + def retrieve_tree(self,patterns,_nodes_list,key): + _heads_list = [] + _root_list = [] + for i in range(len(patterns)): + heads = {} + root = -1 + for j in range(len(patterns[i])): + token_pattern = patterns[i][j] + if('NBOR_RELOP' not in token_pattern['SPEC']): + heads[j] = ('root',j) + root = j + else: + heads[j] = (token_pattern['SPEC']['NBOR_RELOP'],_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]) + + _heads_list.append(heads) + _root_list.append(root) + + _tree_list = [] + for i in range(len(patterns)): + tree = {} + for j in range(len(patterns[i])): + if(_heads_list[i][j][INDEX_HEAD] == j): + continue + + head = _heads_list[i][j][INDEX_HEAD] + if(head not in tree): + tree[head] = [] + tree[head].append( (_heads_list[i][j][INDEX_RELOP],j) ) + _tree_list.append(tree) + + self._tree.setdefault(key, []) + self._tree[key].extend(_tree_list) + + self._root.setdefault(key, []) + self._root[key].extend(_root_list) + + def has_key(self, key): + """Check whether the matcher has a rule with a given key. + + key (string or int): The key to check. + RETURNS (bool): Whether the matcher has the rule. + """ + key = self._normalize_key(key) + return key in self._patterns + + def get(self, key, default=None): + """Retrieve the pattern stored for a key. + + key (unicode or int): The key to retrieve. + RETURNS (tuple): The rule, as an (on_match, patterns) tuple. + """ + key = self._normalize_key(key) + if key not in self._patterns: + return default + return (self._callbacks[key], self._patterns[key]) + + def __call__(self, Doc doc): + matched_trees = [] + + matches = self.token_matcher(doc) + for key in list(self._patterns.keys()): + _patterns_list = self._patterns[key] + _keys_to_token_list = self._keys_to_token[key] + _root_list = self._root[key] + _tree_list = self._tree[key] + _nodes_list = self._nodes[key] + length = len(_patterns_list) + for i in range(length): + _keys_to_token = _keys_to_token_list[i] + _root = _root_list[i] + _tree = _tree_list[i] + _nodes = _nodes_list[i] + id_to_position = {} + for i in range(len(_nodes)): + id_to_position[i]=[] + + # This could be taken outside to improve running time..? + for match_id, start, end in matches: + if match_id in _keys_to_token: + id_to_position[_keys_to_token[match_id]].append(start) + + _node_operator_map = self.get_node_operator_map(doc,_tree,id_to_position,_nodes,_root) + length = len(_nodes) + if _root in id_to_position: + candidates = id_to_position[_root] + for candidate in candidates: + isVisited = {} + self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited,_node_operator_map) + # To check if the subtree pattern is completely identified. This is a heuristic. + # This is done to reduce the complexity of exponential unordered subtree matching. + # Will give approximate matches in some cases. + if(len(isVisited) == length): + matched_trees.append((key,list(isVisited))) + + for i, (ent_id, nodes) in enumerate(matched_trees): + on_match = self._callbacks.get(ent_id) + if on_match is not None: + on_match(self, doc, i, matches) + + return matched_trees + + def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map): + if(root in id_to_position and candidate in id_to_position[root]): + # color the node since it is valid + isVisited[candidate] = True + if root in tree: + for root_child in tree[root]: + if candidate in _node_operator_map and root_child[INDEX_RELOP] in _node_operator_map[candidate]: + candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]] + for candidate_child in candidate_children: + result = self.dfs( + candidate_child.i, + root_child[INDEX_HEAD], + tree, + id_to_position, + doc, + isVisited, + _node_operator_map + ) + + # Given a node and an edge operator, to return the list of nodes + # from the doc that belong to node+operator. This is used to store + # all the results beforehand to prevent unnecessary computation while + # pattern matching + # _node_operator_map[node][operator] = [...] + def get_node_operator_map(self,doc,tree,id_to_position,nodes,root): + _node_operator_map = {} + all_node_indices = nodes.values() + all_operators = [] + for node in all_node_indices: + if node in tree: + for child in tree[node]: + all_operators.append(child[INDEX_RELOP]) + all_operators = list(set(all_operators)) + + all_nodes = [] + for node in all_node_indices: + all_nodes = all_nodes + id_to_position[node] + all_nodes = list(set(all_nodes)) + + for node in all_nodes: + _node_operator_map[node] = {} + for operator in all_operators: + _node_operator_map[node][operator] = [] + + # Used to invoke methods for each operator + switcher = { + '<':self.dep, + '>':self.gov, + '>>':self.dep_chain, + '<<':self.gov_chain, + '.':self.imm_precede, + '$+':self.imm_right_sib, + '$-':self.imm_left_sib, + '$++':self.right_sib, + '$--':self.left_sib + } + for operator in all_operators: + for node in all_nodes: + _node_operator_map[node][operator] = switcher.get(operator)(doc,node) + + return _node_operator_map + + def dep(self,doc,node): + return list(doc[node].head) + + def gov(self,doc,node): + return list(doc[node].children) + + def dep_chain(self,doc,node): + return list(doc[node].ancestors) + + def gov_chain(self,doc,node): + return list(doc[node].subtree) + + def imm_precede(self,doc,node): + if node>0: + return [doc[node-1]] + return [] + + def imm_right_sib(self,doc,node): + for idx in range(list(doc[node].head.children)): + if idx == node-1: + return [doc[idx]] + return [] + + def imm_left_sib(self,doc,node): + for idx in range(list(doc[node].head.children)): + if idx == node+1: + return [doc[idx]] + return [] + + def right_sib(self,doc,node): + candidate_children = [] + for idx in range(list(doc[node].head.children)): + if idx < node: + candidate_children.append(doc[idx]) + return candidate_children + + def left_sib(self,doc,node): + candidate_children = [] + for idx in range(list(doc[node].head.children)): + if idx > node: + candidate_children.append(doc[idx]) + return candidate_children + + def _normalize_key(self, key): + if isinstance(key, basestring): + return self.vocab.strings.add(key) + else: + return key diff --git a/spacy/matcher/matcher.pxd b/spacy/matcher/matcher.pxd new file mode 100644 index 000000000..27d6ae92e --- /dev/null +++ b/spacy/matcher/matcher.pxd @@ -0,0 +1,69 @@ +from libc.stdint cimport int32_t +from libcpp.vector cimport vector +from cymem.cymem cimport Pool + +from ..vocab cimport Vocab +from ..typedefs cimport attr_t, hash_t +from ..structs cimport TokenC +from ..lexeme cimport attr_id_t + + +cdef enum action_t: + REJECT = 0000 + MATCH = 1000 + ADVANCE = 0100 + RETRY = 0010 + RETRY_EXTEND = 0011 + RETRY_ADVANCE = 0110 + MATCH_EXTEND = 1001 + MATCH_REJECT = 2000 + + +cdef enum quantifier_t: + ZERO + ZERO_ONE + ZERO_PLUS + ONE + ONE_PLUS + + +cdef struct AttrValueC: + attr_id_t attr + attr_t value + +cdef struct IndexValueC: + int32_t index + attr_t value + +cdef struct TokenPatternC: + AttrValueC* attrs + int32_t* py_predicates + IndexValueC* extra_attrs + int32_t nr_attr + int32_t nr_extra_attr + int32_t nr_py + quantifier_t quantifier + hash_t key + + +cdef struct PatternStateC: + TokenPatternC* pattern + int32_t start + int32_t length + + +cdef struct MatchC: + attr_t pattern_id + int32_t start + int32_t length + + +cdef class Matcher: + cdef Pool mem + cdef vector[TokenPatternC*] patterns + cdef readonly Vocab vocab + cdef public object _patterns + cdef public object _entities + cdef public object _callbacks + cdef public object _extensions + cdef public object _extra_predicates diff --git a/spacy/matcher.pyx b/spacy/matcher/matcher.pyx similarity index 54% rename from spacy/matcher.pyx rename to spacy/matcher/matcher.pyx index 46e3a2ffe..b12385d4e 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -1,90 +1,25 @@ # cython: infer_types=True # cython: profile=True from __future__ import unicode_literals -import re -import srsly + from libcpp.vector cimport vector -from libc.stdint cimport int32_t, uint64_t, uint16_t -from preshed.maps cimport PreshMap +from libc.stdint cimport int32_t from cymem.cymem cimport Pool from murmurhash.mrmr cimport hash64 -from .typedefs cimport attr_t, hash_t -from .structs cimport TokenC -from .lexeme cimport attr_id_t -from .vocab cimport Vocab -from .tokens.doc cimport Doc -from .tokens.token cimport Token -from .tokens.doc cimport get_token_attr -from .attrs cimport ID, attr_id_t, NULL_ATTR, ORTH -from .errors import Errors, TempErrors, Warnings, deprecation_warning -from .strings import get_string_id -from .attrs import IDS -from .attrs import FLAG61 as U_ENT -from .attrs import FLAG60 as B2_ENT -from .attrs import FLAG59 as B3_ENT -from .attrs import FLAG58 as B4_ENT -from .attrs import FLAG43 as L2_ENT -from .attrs import FLAG42 as L3_ENT -from .attrs import FLAG41 as L4_ENT -from .attrs import FLAG43 as I2_ENT -from .attrs import FLAG42 as I3_ENT -from .attrs import FLAG41 as I4_ENT +import re +import srsly -DELIMITER = '||' +from ..typedefs cimport attr_t +from ..structs cimport TokenC +from ..vocab cimport Vocab +from ..tokens.doc cimport Doc, get_token_attr +from ..tokens.token cimport Token +from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH -DELIMITER = '||' -INDEX_HEAD = 1 -INDEX_RELOP = 0 - -cdef enum action_t: - REJECT = 0000 - MATCH = 1000 - ADVANCE = 0100 - RETRY = 0010 - RETRY_EXTEND = 0011 - RETRY_ADVANCE = 0110 - MATCH_EXTEND = 1001 - MATCH_REJECT = 2000 - - -cdef enum quantifier_t: - ZERO - ZERO_ONE - ZERO_PLUS - ONE - ONE_PLUS - - -cdef struct AttrValueC: - attr_id_t attr - attr_t value - -cdef struct IndexValueC: - int32_t index - attr_t value - -cdef struct TokenPatternC: - AttrValueC* attrs - int32_t* py_predicates - IndexValueC* extra_attrs - int32_t nr_attr - int32_t nr_extra_attr - int32_t nr_py - quantifier_t quantifier - hash_t key - - -cdef struct PatternStateC: - TokenPatternC* pattern - int32_t start - int32_t length - - -cdef struct MatchC: - attr_t pattern_id - int32_t start - int32_t length +from ..errors import Errors +from ..strings import get_string_id +from ..attrs import IDS cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None, @@ -93,10 +28,10 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None, returned as a list of (id, start, end) tuples. To augment the compiled patterns, we optionally also take two Python lists. - + The "predicates" list contains functions that take a Python list and return a boolean value. It's mostly used for regular expressions. - + The "extra_getters" list contains functions that take a Python list and return an attr ID. It's mostly used for extension attributes. ''' @@ -236,7 +171,7 @@ cdef void update_predicate_cache(char* cache, else: raise ValueError("Unexpected value: %s" % result) - + cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) except *: '''Handle states that end in zero-width patterns.''' cdef PatternStateC state @@ -643,14 +578,6 @@ def _get_extensions(spec, string_store, name2index): cdef class Matcher: """Match sequences of tokens, based on pattern rules.""" - cdef Pool mem - cdef vector[TokenPatternC*] patterns - cdef readonly Vocab vocab - cdef public object _patterns - cdef public object _entities - cdef public object _callbacks - cdef public object _extensions - cdef public object _extra_predicates def __init__(self, vocab): """Create the Matcher. @@ -809,537 +736,3 @@ def unpickle_matcher(vocab, patterns, callbacks): callback = callbacks.get(key, None) matcher.add(key, callback, *specs) return matcher - - -def _get_longest_matches(matches): - '''Filter out matches that have a longer equivalent.''' - longest_matches = {} - for pattern_id, start, end in matches: - key = (pattern_id, start) - length = end-start - if key not in longest_matches or length > longest_matches[key]: - longest_matches[key] = length - return [(pattern_id, start, start+length) - for (pattern_id, start), length in longest_matches.items()] - - -def get_bilou(length): - if length == 0: - raise ValueError("Length must be >= 1") - elif length == 1: - return [U_ENT] - elif length == 2: - return [B2_ENT, L2_ENT] - elif length == 3: - return [B3_ENT, I3_ENT, L3_ENT] - else: - return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT] - - -cdef class PhraseMatcher: - cdef Pool mem - cdef Vocab vocab - cdef Matcher matcher - cdef PreshMap phrase_ids - cdef int max_length - cdef attr_id_t attr - cdef public object _callbacks - cdef public object _patterns - - def __init__(self, Vocab vocab, max_length=0, attr='ORTH'): - if max_length != 0: - deprecation_warning(Warnings.W010) - self.mem = Pool() - self.max_length = max_length - self.vocab = vocab - self.matcher = Matcher(self.vocab) - if isinstance(attr, long): - self.attr = attr - else: - self.attr = self.vocab.strings[attr] - self.phrase_ids = PreshMap() - abstract_patterns = [ - [{U_ENT: True}], - [{B2_ENT: True}, {L2_ENT: True}], - [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}], - [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}], - ] - self.matcher.add('Candidate', None, *abstract_patterns) - self._callbacks = {} - - def __len__(self): - """Get the number of rules added to the matcher. Note that this only - returns the number of rules (identical with the number of IDs), not the - number of individual patterns. - - RETURNS (int): The number of rules. - """ - return len(self.phrase_ids) - - def __contains__(self, key): - """Check whether the matcher contains rules for a match ID. - - key (unicode): The match ID. - RETURNS (bool): Whether the matcher contains rules for this match ID. - """ - cdef hash_t ent_id = self.matcher._normalize_key(key) - return ent_id in self._callbacks - - def __reduce__(self): - return (self.__class__, (self.vocab,), None, None) - - def add(self, key, on_match, *docs): - """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID - key, an on_match callback, and one or more patterns. - - key (unicode): The match ID. - on_match (callable): Callback executed on match. - *docs (Doc): `Doc` objects representing match patterns. - """ - cdef Doc doc - cdef hash_t ent_id = self.matcher._normalize_key(key) - self._callbacks[ent_id] = on_match - cdef int length - cdef int i - cdef hash_t phrase_hash - cdef Pool mem = Pool() - for doc in docs: - length = doc.length - if length == 0: - continue - tags = get_bilou(length) - phrase_key = mem.alloc(length, sizeof(attr_t)) - for i, tag in enumerate(tags): - attr_value = self.get_lex_value(doc, i) - lexeme = self.vocab[attr_value] - lexeme.set_flag(tag, True) - phrase_key[i] = lexeme.orth - phrase_hash = hash64(phrase_key, - length * sizeof(attr_t), 0) - self.phrase_ids.set(phrase_hash, ent_id) - - def __call__(self, Doc doc): - - """Find all sequences matching the supplied patterns on the `Doc`. - - doc (Doc): The document to match over. - RETURNS (list): A list of `(key, start, end)` tuples, - describing the matches. A match tuple describes a span - `doc[start:end]`. The `label_id` and `key` are both integers. - """ - matches = [] - if self.attr == ORTH: - match_doc = doc - else: - # If we're not matching on the ORTH, match_doc will be a Doc whose - # token.orth values are the attribute values we're matching on, - # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc]) - words = [self.get_lex_value(doc, i) for i in range(len(doc))] - match_doc = Doc(self.vocab, words=words) - for _, start, end in self.matcher(match_doc): - ent_id = self.accept_match(match_doc, start, end) - if ent_id is not None: - matches.append((ent_id, start, end)) - for i, (ent_id, start, end) in enumerate(matches): - on_match = self._callbacks.get(ent_id) - if on_match is not None: - on_match(self, doc, i, matches) - return matches - - def pipe(self, stream, batch_size=1000, n_threads=1, return_matches=False, - as_tuples=False): - """Match a stream of documents, yielding them in turn. - - docs (iterable): A stream of documents. - batch_size (int): Number of documents to accumulate into a working set. - n_threads (int): The number of threads with which to work on the buffer - in parallel, if the implementation supports multi-threading. - return_matches (bool): Yield the match lists along with the docs, making - results (doc, matches) tuples. - as_tuples (bool): Interpret the input stream as (doc, context) tuples, - and yield (result, context) tuples out. - If both return_matches and as_tuples are True, the output will - be a sequence of ((doc, matches), context) tuples. - YIELDS (Doc): Documents, in order. - """ - if as_tuples: - for doc, context in stream: - matches = self(doc) - if return_matches: - yield ((doc, matches), context) - else: - yield (doc, context) - else: - for doc in stream: - matches = self(doc) - if return_matches: - yield (doc, matches) - else: - yield doc - - def accept_match(self, Doc doc, int start, int end): - cdef int i, j - cdef Pool mem = Pool() - phrase_key = mem.alloc(end-start, sizeof(attr_t)) - for i, j in enumerate(range(start, end)): - phrase_key[i] = doc.c[j].lex.orth - cdef hash_t key = hash64(phrase_key, - (end-start) * sizeof(attr_t), 0) - ent_id = self.phrase_ids.get(key) - if ent_id == 0: - return None - else: - return ent_id - - def get_lex_value(self, Doc doc, int i): - if self.attr == ORTH: - # Return the regular orth value of the lexeme - return doc.c[i].lex.orth - # Get the attribute value instead, e.g. token.pos - attr_value = get_token_attr(&doc.c[i], self.attr) - if attr_value in (0, 1): - # Value is boolean, convert to string - string_attr_value = str(attr_value) - else: - string_attr_value = self.vocab.strings[attr_value] - string_attr_name = self.vocab.strings[self.attr] - # Concatenate the attr name and value to not pollute lexeme space - # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise - # create false positive matches - return 'matcher:{}-{}'.format(string_attr_name, string_attr_value) - - -cdef class DependencyTreeMatcher: - """Match dependency parse tree based on pattern rules.""" - cdef Pool mem - cdef readonly Vocab vocab - cdef readonly Matcher token_matcher - cdef public object _patterns - cdef public object _keys_to_token - cdef public object _root - cdef public object _entities - cdef public object _callbacks - cdef public object _nodes - cdef public object _tree - - def __init__(self, vocab): - """Create the DependencyTreeMatcher. - - vocab (Vocab): The vocabulary object, which must be shared with the - documents the matcher will operate on. - RETURNS (DependencyTreeMatcher): The newly constructed object. - """ - size = 20 - self.token_matcher = Matcher(vocab) - self._keys_to_token = {} - self._patterns = {} - self._root = {} - self._nodes = {} - self._tree = {} - self._entities = {} - self._callbacks = {} - self.vocab = vocab - self.mem = Pool() - - def __reduce__(self): - data = (self.vocab, self._patterns,self._tree, self._callbacks) - return (unpickle_matcher, data, None, None) - - def __len__(self): - """Get the number of rules, which are edges ,added to the dependency tree matcher. - - RETURNS (int): The number of rules. - """ - return len(self._patterns) - - def __contains__(self, key): - """Check whether the matcher contains rules for a match ID. - - key (unicode): The match ID. - RETURNS (bool): Whether the matcher contains rules for this match ID. - """ - return self._normalize_key(key) in self._patterns - - def validateInput(self, pattern, key): - idx = 0 - visitedNodes = {} - for relation in pattern: - if 'PATTERN' not in relation or 'SPEC' not in relation: - raise ValueError(Errors.E098.format(key=key)) - if idx == 0: - if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' not in relation['SPEC'] and 'NBOR_NAME' not in relation['SPEC']): - raise ValueError(Errors.E099.format(key=key)) - visitedNodes[relation['SPEC']['NODE_NAME']] = True - else: - if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' in relation['SPEC'] and 'NBOR_NAME' in relation['SPEC']): - raise ValueError(Errors.E100.format(key=key)) - if relation['SPEC']['NODE_NAME'] in visitedNodes or relation['SPEC']['NBOR_NAME'] not in visitedNodes: - raise ValueError(Errors.E101.format(key=key)) - visitedNodes[relation['SPEC']['NODE_NAME']] = True - visitedNodes[relation['SPEC']['NBOR_NAME']] = True - idx = idx + 1 - - def add(self, key, on_match, *patterns): - for pattern in patterns: - if len(pattern) == 0: - raise ValueError(Errors.E012.format(key=key)) - self.validateInput(pattern,key) - - key = self._normalize_key(key) - - _patterns = [] - for pattern in patterns: - token_patterns = [] - for i in range(len(pattern)): - token_pattern = [pattern[i]['PATTERN']] - token_patterns.append(token_pattern) - # self.patterns.append(token_patterns) - _patterns.append(token_patterns) - - self._patterns.setdefault(key, []) - self._callbacks[key] = on_match - self._patterns[key].extend(_patterns) - - # Add each node pattern of all the input patterns individually to the matcher. - # This enables only a single instance of Matcher to be used. - # Multiple adds are required to track each node pattern. - _keys_to_token_list = [] - for i in range(len(_patterns)): - _keys_to_token = {} - # TODO : Better ways to hash edges in pattern? - for j in range(len(_patterns[i])): - k = self._normalize_key(unicode(key)+DELIMITER+unicode(i)+DELIMITER+unicode(j)) - self.token_matcher.add(k,None,_patterns[i][j]) - _keys_to_token[k] = j - _keys_to_token_list.append(_keys_to_token) - - self._keys_to_token.setdefault(key, []) - self._keys_to_token[key].extend(_keys_to_token_list) - - _nodes_list = [] - for pattern in patterns: - nodes = {} - for i in range(len(pattern)): - nodes[pattern[i]['SPEC']['NODE_NAME']]=i - _nodes_list.append(nodes) - - self._nodes.setdefault(key, []) - self._nodes[key].extend(_nodes_list) - - # Create an object tree to traverse later on. - # This datastructure enable easy tree pattern match. - # Doc-Token based tree cannot be reused since it is memory heavy and - # tightly coupled with doc - self.retrieve_tree(patterns,_nodes_list,key) - - def retrieve_tree(self,patterns,_nodes_list,key): - _heads_list = [] - _root_list = [] - for i in range(len(patterns)): - heads = {} - root = -1 - for j in range(len(patterns[i])): - token_pattern = patterns[i][j] - if('NBOR_RELOP' not in token_pattern['SPEC']): - heads[j] = ('root',j) - root = j - else: - heads[j] = (token_pattern['SPEC']['NBOR_RELOP'],_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]) - - _heads_list.append(heads) - _root_list.append(root) - - _tree_list = [] - for i in range(len(patterns)): - tree = {} - for j in range(len(patterns[i])): - if(_heads_list[i][j][INDEX_HEAD] == j): - continue - - head = _heads_list[i][j][INDEX_HEAD] - if(head not in tree): - tree[head] = [] - tree[head].append( (_heads_list[i][j][INDEX_RELOP],j) ) - _tree_list.append(tree) - - self._tree.setdefault(key, []) - self._tree[key].extend(_tree_list) - - self._root.setdefault(key, []) - self._root[key].extend(_root_list) - - def has_key(self, key): - """Check whether the matcher has a rule with a given key. - - key (string or int): The key to check. - RETURNS (bool): Whether the matcher has the rule. - """ - key = self._normalize_key(key) - return key in self._patterns - - def get(self, key, default=None): - """Retrieve the pattern stored for a key. - - key (unicode or int): The key to retrieve. - RETURNS (tuple): The rule, as an (on_match, patterns) tuple. - """ - key = self._normalize_key(key) - if key not in self._patterns: - return default - return (self._callbacks[key], self._patterns[key]) - - def __call__(self, Doc doc): - matched_trees = [] - - matches = self.token_matcher(doc) - for key in list(self._patterns.keys()): - _patterns_list = self._patterns[key] - _keys_to_token_list = self._keys_to_token[key] - _root_list = self._root[key] - _tree_list = self._tree[key] - _nodes_list = self._nodes[key] - length = len(_patterns_list) - for i in range(length): - _keys_to_token = _keys_to_token_list[i] - _root = _root_list[i] - _tree = _tree_list[i] - _nodes = _nodes_list[i] - id_to_position = {} - for i in range(len(_nodes)): - id_to_position[i]=[] - - # This could be taken outside to improve running time..? - for match_id, start, end in matches: - if match_id in _keys_to_token: - id_to_position[_keys_to_token[match_id]].append(start) - - _node_operator_map = self.get_node_operator_map(doc,_tree,id_to_position,_nodes,_root) - length = len(_nodes) - if _root in id_to_position: - candidates = id_to_position[_root] - for candidate in candidates: - isVisited = {} - self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited,_node_operator_map) - # To check if the subtree pattern is completely identified. This is a heuristic. - # This is done to reduce the complexity of exponential unordered subtree matching. - # Will give approximate matches in some cases. - if(len(isVisited) == length): - matched_trees.append((key,list(isVisited))) - - for i, (ent_id, nodes) in enumerate(matched_trees): - on_match = self._callbacks.get(ent_id) - if on_match is not None: - on_match(self, doc, i, matches) - - return matched_trees - - def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map): - if(root in id_to_position and candidate in id_to_position[root]): - # color the node since it is valid - isVisited[candidate] = True - if root in tree: - for root_child in tree[root]: - if candidate in _node_operator_map and root_child[INDEX_RELOP] in _node_operator_map[candidate]: - candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]] - for candidate_child in candidate_children: - result = self.dfs( - candidate_child.i, - root_child[INDEX_HEAD], - tree, - id_to_position, - doc, - isVisited, - _node_operator_map - ) - - # Given a node and an edge operator, to return the list of nodes - # from the doc that belong to node+operator. This is used to store - # all the results beforehand to prevent unnecessary computation while - # pattern matching - # _node_operator_map[node][operator] = [...] - def get_node_operator_map(self,doc,tree,id_to_position,nodes,root): - _node_operator_map = {} - all_node_indices = nodes.values() - all_operators = [] - for node in all_node_indices: - if node in tree: - for child in tree[node]: - all_operators.append(child[INDEX_RELOP]) - all_operators = list(set(all_operators)) - - all_nodes = [] - for node in all_node_indices: - all_nodes = all_nodes + id_to_position[node] - all_nodes = list(set(all_nodes)) - - for node in all_nodes: - _node_operator_map[node] = {} - for operator in all_operators: - _node_operator_map[node][operator] = [] - - # Used to invoke methods for each operator - switcher = { - '<':self.dep, - '>':self.gov, - '>>':self.dep_chain, - '<<':self.gov_chain, - '.':self.imm_precede, - '$+':self.imm_right_sib, - '$-':self.imm_left_sib, - '$++':self.right_sib, - '$--':self.left_sib - } - for operator in all_operators: - for node in all_nodes: - _node_operator_map[node][operator] = switcher.get(operator)(doc,node) - - return _node_operator_map - - def dep(self,doc,node): - return list(doc[node].head) - - def gov(self,doc,node): - return list(doc[node].children) - - def dep_chain(self,doc,node): - return list(doc[node].ancestors) - - def gov_chain(self,doc,node): - return list(doc[node].subtree) - - def imm_precede(self,doc,node): - if node>0: - return [doc[node-1]] - return [] - - def imm_right_sib(self,doc,node): - for idx in range(list(doc[node].head.children)): - if idx == node-1: - return [doc[idx]] - return [] - - def imm_left_sib(self,doc,node): - for idx in range(list(doc[node].head.children)): - if idx == node+1: - return [doc[idx]] - return [] - - def right_sib(self,doc,node): - candidate_children = [] - for idx in range(list(doc[node].head.children)): - if idx < node: - candidate_children.append(doc[idx]) - return candidate_children - - def left_sib(self,doc,node): - candidate_children = [] - for idx in range(list(doc[node].head.children)): - if idx > node: - candidate_children.append(doc[idx]) - return candidate_children - - def _normalize_key(self, key): - if isinstance(key, basestring): - return self.vocab.strings.add(key) - else: - return key diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx new file mode 100644 index 000000000..cb3f0c0f4 --- /dev/null +++ b/spacy/matcher/phrasematcher.pyx @@ -0,0 +1,210 @@ +# cython: infer_types=True +# cython: profile=True +from __future__ import unicode_literals + +from cymem.cymem cimport Pool +from murmurhash.mrmr cimport hash64 +from preshed.maps cimport PreshMap + +from .matcher cimport Matcher +from ..attrs cimport ORTH, attr_id_t +from ..vocab cimport Vocab +from ..tokens.doc cimport Doc, get_token_attr +from ..typedefs cimport attr_t, hash_t + +from ..errors import Warnings, deprecation_warning +from ..attrs import FLAG61 as U_ENT +from ..attrs import FLAG60 as B2_ENT +from ..attrs import FLAG59 as B3_ENT +from ..attrs import FLAG58 as B4_ENT +from ..attrs import FLAG43 as L2_ENT +from ..attrs import FLAG42 as L3_ENT +from ..attrs import FLAG41 as L4_ENT +from ..attrs import FLAG42 as I3_ENT +from ..attrs import FLAG41 as I4_ENT + + +cdef class PhraseMatcher: + cdef Pool mem + cdef Vocab vocab + cdef Matcher matcher + cdef PreshMap phrase_ids + cdef int max_length + cdef attr_id_t attr + cdef public object _callbacks + cdef public object _patterns + + def __init__(self, Vocab vocab, max_length=0, attr='ORTH'): + if max_length != 0: + deprecation_warning(Warnings.W010) + self.mem = Pool() + self.max_length = max_length + self.vocab = vocab + self.matcher = Matcher(self.vocab) + if isinstance(attr, long): + self.attr = attr + else: + self.attr = self.vocab.strings[attr] + self.phrase_ids = PreshMap() + abstract_patterns = [ + [{U_ENT: True}], + [{B2_ENT: True}, {L2_ENT: True}], + [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}], + [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}], + ] + self.matcher.add('Candidate', None, *abstract_patterns) + self._callbacks = {} + + def __len__(self): + """Get the number of rules added to the matcher. Note that this only + returns the number of rules (identical with the number of IDs), not the + number of individual patterns. + + RETURNS (int): The number of rules. + """ + return len(self.phrase_ids) + + def __contains__(self, key): + """Check whether the matcher contains rules for a match ID. + + key (unicode): The match ID. + RETURNS (bool): Whether the matcher contains rules for this match ID. + """ + cdef hash_t ent_id = self.matcher._normalize_key(key) + return ent_id in self._callbacks + + def __reduce__(self): + return (self.__class__, (self.vocab,), None, None) + + def add(self, key, on_match, *docs): + """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID + key, an on_match callback, and one or more patterns. + + key (unicode): The match ID. + on_match (callable): Callback executed on match. + *docs (Doc): `Doc` objects representing match patterns. + """ + cdef Doc doc + cdef hash_t ent_id = self.matcher._normalize_key(key) + self._callbacks[ent_id] = on_match + cdef int length + cdef int i + cdef hash_t phrase_hash + cdef Pool mem = Pool() + for doc in docs: + length = doc.length + if length == 0: + continue + tags = get_bilou(length) + phrase_key = mem.alloc(length, sizeof(attr_t)) + for i, tag in enumerate(tags): + attr_value = self.get_lex_value(doc, i) + lexeme = self.vocab[attr_value] + lexeme.set_flag(tag, True) + phrase_key[i] = lexeme.orth + phrase_hash = hash64(phrase_key, + length * sizeof(attr_t), 0) + self.phrase_ids.set(phrase_hash, ent_id) + + def __call__(self, Doc doc): + + """Find all sequences matching the supplied patterns on the `Doc`. + + doc (Doc): The document to match over. + RETURNS (list): A list of `(key, start, end)` tuples, + describing the matches. A match tuple describes a span + `doc[start:end]`. The `label_id` and `key` are both integers. + """ + matches = [] + if self.attr == ORTH: + match_doc = doc + else: + # If we're not matching on the ORTH, match_doc will be a Doc whose + # token.orth values are the attribute values we're matching on, + # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc]) + words = [self.get_lex_value(doc, i) for i in range(len(doc))] + match_doc = Doc(self.vocab, words=words) + for _, start, end in self.matcher(match_doc): + ent_id = self.accept_match(match_doc, start, end) + if ent_id is not None: + matches.append((ent_id, start, end)) + for i, (ent_id, start, end) in enumerate(matches): + on_match = self._callbacks.get(ent_id) + if on_match is not None: + on_match(self, doc, i, matches) + return matches + + def pipe(self, stream, batch_size=1000, n_threads=1, return_matches=False, + as_tuples=False): + """Match a stream of documents, yielding them in turn. + + docs (iterable): A stream of documents. + batch_size (int): Number of documents to accumulate into a working set. + n_threads (int): The number of threads with which to work on the buffer + in parallel, if the implementation supports multi-threading. + return_matches (bool): Yield the match lists along with the docs, making + results (doc, matches) tuples. + as_tuples (bool): Interpret the input stream as (doc, context) tuples, + and yield (result, context) tuples out. + If both return_matches and as_tuples are True, the output will + be a sequence of ((doc, matches), context) tuples. + YIELDS (Doc): Documents, in order. + """ + if as_tuples: + for doc, context in stream: + matches = self(doc) + if return_matches: + yield ((doc, matches), context) + else: + yield (doc, context) + else: + for doc in stream: + matches = self(doc) + if return_matches: + yield (doc, matches) + else: + yield doc + + def accept_match(self, Doc doc, int start, int end): + cdef int i, j + cdef Pool mem = Pool() + phrase_key = mem.alloc(end-start, sizeof(attr_t)) + for i, j in enumerate(range(start, end)): + phrase_key[i] = doc.c[j].lex.orth + cdef hash_t key = hash64(phrase_key, + (end-start) * sizeof(attr_t), 0) + ent_id = self.phrase_ids.get(key) + if ent_id == 0: + return None + else: + return ent_id + + def get_lex_value(self, Doc doc, int i): + if self.attr == ORTH: + # Return the regular orth value of the lexeme + return doc.c[i].lex.orth + # Get the attribute value instead, e.g. token.pos + attr_value = get_token_attr(&doc.c[i], self.attr) + if attr_value in (0, 1): + # Value is boolean, convert to string + string_attr_value = str(attr_value) + else: + string_attr_value = self.vocab.strings[attr_value] + string_attr_name = self.vocab.strings[self.attr] + # Concatenate the attr name and value to not pollute lexeme space + # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise + # create false positive matches + return 'matcher:{}-{}'.format(string_attr_name, string_attr_value) + + +def get_bilou(length): + if length == 0: + raise ValueError("Length must be >= 1") + elif length == 1: + return [U_ENT] + elif length == 2: + return [B2_ENT, L2_ENT] + elif length == 3: + return [B3_ENT, I3_ENT, L3_ENT] + else: + return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]