diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index 93ee3d984..cf8ee089d 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -48,6 +48,7 @@ from .attrs import FLAG37 as L8_ENT from .attrs import FLAG36 as L9_ENT from .attrs import FLAG35 as L10_ENT +DELIMITER = '||' cpdef enum quantifier_t: _META @@ -66,7 +67,7 @@ cdef enum action_t: ACCEPT_PREV PANIC -# A "match expression" conists of one or more token patterns +# A "match expression" consists of one or more token patterns # Each token pattern consists of a quantifier and 0+ (attr, value) pairs. # A state is an (int, pattern pointer) pair, where the int is the start # position, and the pattern pointer shows where we're up to @@ -76,16 +77,16 @@ cdef struct AttrValueC: attr_id_t attr attr_t value - cdef struct TokenPatternC: AttrValueC* attrs int32_t nr_attr quantifier_t quantifier - ctypedef TokenPatternC* TokenPatternC_ptr ctypedef pair[int, TokenPatternC_ptr] StateC +DEF PADDING = 5 + cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL: @@ -105,7 +106,6 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, pattern[i].nr_attr = 0 return pattern - cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0: while pattern.nr_attr != 0: pattern += 1 @@ -262,7 +262,7 @@ cdef class Matcher: key (unicode): The match ID. on_match (callable): Callback executed on match. - *patterns (list): List of token descritions. + *patterns (list): List of token descriptions. """ for pattern in patterns: if len(pattern) == 0: @@ -526,6 +526,7 @@ cdef class PhraseMatcher: self.phrase_ids.set(phrase_hash, ent_id) def __call__(self, Doc doc): + """Find all sequences matching the supplied patterns on the `Doc`. doc (Doc): The document to match over. @@ -573,3 +574,236 @@ cdef class PhraseMatcher: return None else: return ent_id + +cdef class DependencyTreeMatcher: + """Match dependency parse tree based on pattern rules.""" + cdef Pool mem + cdef readonly Vocab vocab + cdef readonly Matcher token_matcher + cdef public object _patterns + cdef public object _keys_to_token + cdef public object _root + cdef public object _entities + cdef public object _callbacks + cdef public object _nodes + cdef public object _tree + + def __init__(self, vocab): + """Create the DependencyTreeMatcher. + + vocab (Vocab): The vocabulary object, which must be shared with the + documents the matcher will operate on. + RETURNS (DependencyTreeMatcher): The newly constructed object. + """ + size = 20 + self.token_matcher = Matcher(vocab) + self._keys_to_token = {} + self._patterns = {} + self._root = {} + self._nodes = {} + self._tree = {} + self._entities = {} + self._callbacks = {} + self.vocab = vocab + self.mem = Pool() + + def __reduce__(self): + data = (self.vocab, self._patterns,self._tree, self._callbacks) + return (unpickle_matcher, data, None, None) + + def __len__(self): + """Get the number of rules, which are edges ,added to the dependency tree matcher. + + RETURNS (int): The number of rules. + """ + return len(self._patterns) + + def __contains__(self, key): + """Check whether the matcher contains rules for a match ID. + + key (unicode): The match ID. + RETURNS (bool): Whether the matcher contains rules for this match ID. + """ + return self._normalize_key(key) in self._patterns + + + def add(self, key, on_match, *patterns): + + # TODO : validations + # 1. check if input pattern is connected + # 2. check if pattern format is correct + # 3. check if atleast one root node is present + # 4. check if node names are not repeated + # 5. check if each node has only one head + + for pattern in patterns: + if len(pattern) == 0: + raise ValueError(Errors.E012.format(key=key)) + + key = self._normalize_key(key) + + _patterns = [] + for pattern in patterns: + token_patterns = [] + for i in range(len(pattern)): + token_pattern = [pattern[i]['PATTERN']] + token_patterns.append(token_pattern) + # self.patterns.append(token_patterns) + _patterns.append(token_patterns) + + self._patterns.setdefault(key, []) + self._callbacks[key] = on_match + self._patterns[key].extend(_patterns) + + # Add each node pattern of all the input patterns individually to the matcher. + # This enables only a single instance of Matcher to be used. + # Multiple adds are required to track each node pattern. + _keys_to_token_list = [] + for i in range(len(_patterns)): + _keys_to_token = {} + # TODO : Better ways to hash edges in pattern? + for j in range(len(_patterns[i])): + k = self._normalize_key(unicode(key)+DELIMITER+unicode(i)+DELIMITER+unicode(j)) + self.token_matcher.add(k,None,_patterns[i][j]) + _keys_to_token[k] = j + _keys_to_token_list.append(_keys_to_token) + + self._keys_to_token.setdefault(key, []) + self._keys_to_token[key].extend(_keys_to_token_list) + + _nodes_list = [] + for pattern in patterns: + nodes = {} + for i in range(len(pattern)): + nodes[pattern[i]['SPEC']['NODE_NAME']]=i + _nodes_list.append(nodes) + + self._nodes.setdefault(key, []) + self._nodes[key].extend(_nodes_list) + + # Create an object tree to traverse later on. + # This datastructure enable easy tree pattern match. + # Doc-Token based tree cannot be reused since it is memory heavy and tightly coupled with doc + self.retrieve_tree(patterns,_nodes_list,key) + + def retrieve_tree(self,patterns,_nodes_list,key): + + _heads_list = [] + _root_list = [] + for i in range(len(patterns)): + heads = {} + root = -1 + for j in range(len(patterns[i])): + token_pattern = patterns[i][j] + if('NBOR_RELOP' not in token_pattern['SPEC']): + heads[j] = j + root = j + else: + # TODO: Add semgrex rules + # 1. > + if(token_pattern['SPEC']['NBOR_RELOP'] == '>'): + heads[j] = _nodes_list[i][token_pattern['SPEC']['NBOR_NAME']] + # 2. < + if(token_pattern['SPEC']['NBOR_RELOP'] == '<'): + heads[_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]] = j + + _heads_list.append(heads) + _root_list.append(root) + + _tree_list = [] + for i in range(len(patterns)): + tree = {} + for j in range(len(patterns[i])): + if(j == _heads_list[i][j]): + continue + head = _heads_list[i][j] + if(head not in tree): + tree[head] = [] + tree[head].append(j) + _tree_list.append(tree) + + self._tree.setdefault(key, []) + self._tree[key].extend(_tree_list) + + self._root.setdefault(key, []) + self._root[key].extend(_root_list) + + def has_key(self, key): + """Check whether the matcher has a rule with a given key. + + key (string or int): The key to check. + RETURNS (bool): Whether the matcher has the rule. + """ + key = self._normalize_key(key) + return key in self._patterns + + def get(self, key, default=None): + """Retrieve the pattern stored for a key. + + key (unicode or int): The key to retrieve. + RETURNS (tuple): The rule, as an (on_match, patterns) tuple. + """ + key = self._normalize_key(key) + if key not in self._patterns: + return default + return (self._callbacks[key], self._patterns[key]) + + def __call__(self, Doc doc): + matched_trees = [] + + matches = self.token_matcher(doc) + for key in list(self._patterns.keys()): + _patterns_list = self._patterns[key] + _keys_to_token_list = self._keys_to_token[key] + _root_list = self._root[key] + _tree_list = self._tree[key] + _nodes_list = self._nodes[key] + length = len(_patterns_list) + for i in range(length): + _keys_to_token = _keys_to_token_list[i] + _root = _root_list[i] + _tree = _tree_list[i] + _nodes = _nodes_list[i] + + id_to_position = {} + + # This could be taken outside to improve running time..? + for match_id, start, end in matches: + if match_id in _keys_to_token: + if _keys_to_token[match_id] not in id_to_position: + id_to_position[_keys_to_token[match_id]] = [] + id_to_position[_keys_to_token[match_id]].append(start) + + length = len(_nodes) + if _root in id_to_position: + candidates = id_to_position[_root] + for candidate in candidates: + isVisited = {} + self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited) + # to check if the subtree pattern is completely identified + if(len(isVisited) == length): + matched_trees.append((key,list(isVisited))) + + for i, (ent_id, nodes) in enumerate(matched_trees): + on_match = self._callbacks.get(ent_id) + if on_match is not None: + on_match(self, doc, i, matches) + + return matched_trees + + def dfs(self,candidate,root,tree,id_to_position,doc,isVisited): + if(root in id_to_position and candidate in id_to_position[root]): + # color the node since it is valid + isVisited[candidate] = True + candidate_children = doc[candidate].children + for candidate_child in candidate_children: + if root in tree: + for root_child in tree[root]: + self.dfs(candidate_child.i,root_child,tree,id_to_position,doc,isVisited) + + + def _normalize_key(self, key): + if isinstance(key, basestring): + return self.vocab.strings.add(key) + else: + return key \ No newline at end of file diff --git a/spacy/tests/test_matcher.py b/spacy/tests/test_matcher.py index 8210467ea..bb3013dbd 100644 --- a/spacy/tests/test_matcher.py +++ b/spacy/tests/test_matcher.py @@ -1,12 +1,14 @@ # coding: utf-8 from __future__ import unicode_literals -from ..matcher import Matcher, PhraseMatcher +from numpy import sort + +from ..matcher import Matcher, PhraseMatcher, DependencyTreeMatcher from .util import get_doc from ..tokens import Doc import pytest - +import re @pytest.fixture def matcher(en_vocab): @@ -20,7 +22,6 @@ def matcher(en_vocab): matcher.add(key, None, *patterns) return matcher - def test_matcher_from_api_docs(en_vocab): matcher = Matcher(en_vocab) pattern = [{'ORTH': 'test'}] @@ -258,3 +259,47 @@ def test_matcher_end_zero_plus(matcher): assert len(matcher(nlp(u'a b c'))) == 1 assert len(matcher(nlp(u'a b b c'))) == 1 assert len(matcher(nlp(u'a b b'))) == 1 + + +@pytest.fixture +def text(): + return u"The quick brown fox jumped over the lazy fox" + +@pytest.fixture +def heads(): + return [3,2,1,1,0,-1,2,1,-3] + +@pytest.fixture +def deps(): + return ['det', 'amod', 'amod', 'nsubj', 'prep', 'pobj', 'det', 'amod'] + +@pytest.fixture +def dependency_tree_matcher(en_vocab): + is_brown_yellow = lambda text: bool(re.compile(r'brown|yellow|over').match(text)) + IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow) + pattern1 = [ + {'SPEC': {'NODE_NAME': 'fox'}, 'PATTERN': {'ORTH': 'fox'}}, + {'SPEC': {'NODE_NAME': 'q', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'},'PATTERN': {'LOWER': u'quick'}}, + {'SPEC': {'NODE_NAME': 'r', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'}, 'PATTERN': {IS_BROWN_YELLOW: True}} + ] + + pattern2 = [ + {'SPEC': {'NODE_NAME': 'jumped'}, 'PATTERN': {'ORTH': 'jumped'}}, + {'SPEC': {'NODE_NAME': 'fox', 'NBOR_RELOP': '>', 'NBOR_NAME': 'jumped'},'PATTERN': {'LOWER': u'fox'}}, + {'SPEC': {'NODE_NAME': 'over', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'}, 'PATTERN': {IS_BROWN_YELLOW: True}} + ] + matcher = DependencyTreeMatcher(en_vocab) + matcher.add('pattern1', None, pattern1) + matcher.add('pattern2', None, pattern2) + return matcher + + + +def test_dependency_tree_matcher_compile(dependency_tree_matcher): + assert len(dependency_tree_matcher) == 2 + +def test_dependency_tree_matcher(dependency_tree_matcher,text,heads,deps): + doc = get_doc(dependency_tree_matcher.vocab,text.split(),heads=heads,deps=deps) + matches = dependency_tree_matcher(doc) + assert len(matches) == 2 +