Dependency tree pattern matcher (#3465)

* Functional dependency tree pattern matcher

* Tests fail due to inconsistent behaviour

* Renamed dependencymatcher and added optimizations
This commit is contained in:
Suraj Rajan 2019-06-16 16:55:32 +05:30 committed by Matthew Honnibal
parent 3f52e12335
commit 46c78d0a41
3 changed files with 74 additions and 88 deletions

View File

@ -3,6 +3,6 @@ from __future__ import unicode_literals
from .matcher import Matcher
from .phrasematcher import PhraseMatcher
from .dependencymatcher import DependencyTreeMatcher
from .dependencymatcher import DependencyMatcher
__all__ = ["Matcher", "PhraseMatcher", "DependencyTreeMatcher"]
__all__ = ["Matcher", "PhraseMatcher", "DependencyMatcher"]

View File

@ -12,13 +12,15 @@ from ..tokens.doc cimport Doc
from .matcher import unpickle_matcher
from ..errors import Errors
from libcpp cimport bool
import numpy
DELIMITER = "||"
INDEX_HEAD = 1
INDEX_RELOP = 0
cdef class DependencyTreeMatcher:
cdef class DependencyMatcher:
"""Match dependency parse tree based on pattern rules."""
cdef Pool mem
cdef readonly Vocab vocab
@ -32,11 +34,11 @@ cdef class DependencyTreeMatcher:
cdef public object _tree
def __init__(self, vocab):
"""Create the DependencyTreeMatcher.
"""Create the DependencyMatcher.
vocab (Vocab): The vocabulary object, which must be shared with the
documents the matcher will operate on.
RETURNS (DependencyTreeMatcher): The newly constructed object.
RETURNS (DependencyMatcher): The newly constructed object.
"""
size = 20
self.token_matcher = Matcher(vocab)
@ -199,7 +201,7 @@ cdef class DependencyTreeMatcher:
return (self._callbacks[key], self._patterns[key])
def __call__(self, Doc doc):
matched_trees = []
matched_key_trees = []
matches = self.token_matcher(doc)
for key in list(self._patterns.keys()):
_patterns_list = self._patterns[key]
@ -227,51 +229,36 @@ cdef class DependencyTreeMatcher:
_nodes,_root
)
length = len(_nodes)
if _root in id_to_position:
candidates = id_to_position[_root]
for candidate in candidates:
isVisited = {}
self.dfs(
candidate,
_root,_tree,
id_to_position,
doc,
isVisited,
_node_operator_map
)
# To check if the subtree pattern is completely
# identified. This is a heuristic. This is done to
# reduce the complexity of exponential unordered subtree
# matching. Will give approximate matches in some cases.
if(len(isVisited) == length):
matched_trees.append((key,list(isVisited)))
for i, (ent_id, nodes) in enumerate(matched_trees):
matched_trees = []
self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees)
matched_key_trees.append((key,matched_trees))
for i, (ent_id, nodes) in enumerate(matched_key_trees):
on_match = self._callbacks.get(ent_id)
if on_match is not None:
on_match(self, doc, i, matches)
return matched_trees
return matched_key_trees
def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map):
if (root in id_to_position and candidate in id_to_position[root]):
# Color the node since it is valid
isVisited[candidate] = True
if root in tree:
for root_child in tree[root]:
if (
candidate in _node_operator_map
and root_child[INDEX_RELOP] in _node_operator_map[candidate]
):
candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]]
for candidate_child in candidate_children:
result = self.dfs(
candidate_child.i,
root_child[INDEX_HEAD],
tree,
id_to_position,
doc,
isVisited,
_node_operator_map
)
def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visitedNodes,matched_trees):
cdef bool isValid;
if(patternLength == len(id_to_position.keys())):
isValid = True
for node in range(patternLength):
if(node in tree):
for idx, (relop,nbor) in enumerate(tree[node]):
computed_nbors = numpy.asarray(_node_operator_map[visitedNodes[node]][relop])
isNbor = False
for computed_nbor in computed_nbors:
if(computed_nbor.i == visitedNodes[nbor]):
isNbor = True
isValid = isValid & isNbor
if(isValid):
matched_trees.append(visitedNodes)
return
allPatternNodes = numpy.asarray(id_to_position[patternLength])
for patternNode in allPatternNodes:
self.recurse(tree,id_to_position,_node_operator_map,patternLength+1,visitedNodes+[patternNode],matched_trees)
# Given a node and an edge operator, to return the list of nodes
# from the doc that belong to node+operator. This is used to store
@ -299,8 +286,8 @@ cdef class DependencyTreeMatcher:
switcher = {
"<": self.dep,
">": self.gov,
">>": self.dep_chain,
"<<": self.gov_chain,
"<<": self.dep_chain,
">>": self.gov_chain,
".": self.imm_precede,
"$+": self.imm_right_sib,
"$-": self.imm_left_sib,
@ -313,7 +300,7 @@ cdef class DependencyTreeMatcher:
return _node_operator_map
def dep(self, doc, node):
return list(doc[node].head)
return [doc[node].head]
def gov(self,doc,node):
return list(doc[node].children)
@ -330,29 +317,29 @@ cdef class DependencyTreeMatcher:
return []
def imm_right_sib(self, doc, node):
for idx in range(list(doc[node].head.children)):
if idx == node - 1:
return [doc[idx]]
for child in list(doc[node].head.children):
if child.i == node - 1:
return [doc[child.i]]
return []
def imm_left_sib(self, doc, node):
for idx in range(list(doc[node].head.children)):
if idx == node + 1:
return [doc[idx]]
for child in list(doc[node].head.children):
if child.i == node + 1:
return [doc[child.i]]
return []
def right_sib(self, doc, node):
candidate_children = []
for idx in range(list(doc[node].head.children)):
if idx < node:
candidate_children.append(doc[idx])
for child in list(doc[node].head.children):
if child.i < node:
candidate_children.append(doc[child.i])
return candidate_children
def left_sib(self, doc, node):
candidate_children = []
for idx in range(list(doc[node].head.children)):
if idx > node:
candidate_children.append(doc[idx])
for child in list(doc[node].head.children):
if child.i > node:
candidate_children.append(doc[child.i])
return candidate_children
def _normalize_key(self, key):

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import pytest
import re
from spacy.matcher import Matcher, DependencyTreeMatcher
from spacy.matcher import Matcher, DependencyMatcher
from spacy.tokens import Doc, Token
from ..util import get_doc
@ -285,45 +285,44 @@ def deps():
@pytest.fixture
def dependency_tree_matcher(en_vocab):
def dependency_matcher(en_vocab):
def is_brown_yellow(text):
return bool(re.compile(r"brown|yellow|over").match(text))
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
pattern1 = [
{"SPEC": {"NODE_NAME": "fox"}, "PATTERN": {"ORTH": "fox"}},
{
"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
"PATTERN": {"LOWER": "quick"},
},
{
"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
"PATTERN": {IS_BROWN_YELLOW: True},
},
{"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},"PATTERN": {"ORTH": "quick", "DEP": "amod"}},
{"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"}, "PATTERN": {IS_BROWN_YELLOW: True}},
]
pattern2 = [
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
{
"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
"PATTERN": {"LOWER": "fox"},
},
{
"SPEC": {"NODE_NAME": "over", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
"PATTERN": {IS_BROWN_YELLOW: True},
},
{"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}},
{"SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}}
]
matcher = DependencyTreeMatcher(en_vocab)
pattern3 = [
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
{"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}},
{"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"}, "PATTERN": {"ORTH": "brown"}}
]
matcher = DependencyMatcher(en_vocab)
matcher.add("pattern1", None, pattern1)
matcher.add("pattern2", None, pattern2)
matcher.add("pattern3", None, pattern3)
return matcher
def test_dependency_tree_matcher_compile(dependency_tree_matcher):
assert len(dependency_tree_matcher) == 2
def test_dependency_matcher_compile(dependency_matcher):
assert len(dependency_matcher) == 3
def test_dependency_tree_matcher(dependency_tree_matcher, text, heads, deps):
doc = get_doc(dependency_tree_matcher.vocab, text.split(), heads=heads, deps=deps)
matches = dependency_tree_matcher(doc)
assert len(matches) == 2
def test_dependency_matcher(dependency_matcher, text, heads, deps):
doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps)
matches = dependency_matcher(doc)
# assert matches[0][1] == [[3, 1, 2]]
# assert matches[1][1] == [[4, 3, 3]]
# assert matches[2][1] == [[4, 3, 2]]