mirror of https://github.com/explosion/spaCy.git
Dependency tree pattern matcher (#3465)
* Functional dependency tree pattern matcher
* Tests fail due to inconsistent behaviour
* Renamed dependencymatcher and added optimizations
This commit is contained in:
parent
3f52e12335
commit
46c78d0a41
|
@ -3,6 +3,6 @@ from __future__ import unicode_literals
|
|||
|
||||
from .matcher import Matcher
|
||||
from .phrasematcher import PhraseMatcher
|
||||
from .dependencymatcher import DependencyTreeMatcher
|
||||
from .dependencymatcher import DependencyMatcher
|
||||
|
||||
__all__ = ["Matcher", "PhraseMatcher", "DependencyTreeMatcher"]
|
||||
__all__ = ["Matcher", "PhraseMatcher", "DependencyMatcher"]
|
||||
|
|
|
@ -12,13 +12,15 @@ from ..tokens.doc cimport Doc
|
|||
from .matcher import unpickle_matcher
|
||||
from ..errors import Errors
|
||||
|
||||
from libcpp cimport bool
|
||||
import numpy
|
||||
|
||||
DELIMITER = "||"
|
||||
INDEX_HEAD = 1
|
||||
INDEX_RELOP = 0
|
||||
|
||||
|
||||
cdef class DependencyTreeMatcher:
|
||||
cdef class DependencyMatcher:
|
||||
"""Match dependency parse tree based on pattern rules."""
|
||||
cdef Pool mem
|
||||
cdef readonly Vocab vocab
|
||||
|
@ -32,11 +34,11 @@ cdef class DependencyTreeMatcher:
|
|||
cdef public object _tree
|
||||
|
||||
def __init__(self, vocab):
|
||||
"""Create the DependencyTreeMatcher.
|
||||
"""Create the DependencyMatcher.
|
||||
|
||||
vocab (Vocab): The vocabulary object, which must be shared with the
|
||||
documents the matcher will operate on.
|
||||
RETURNS (DependencyTreeMatcher): The newly constructed object.
|
||||
RETURNS (DependencyMatcher): The newly constructed object.
|
||||
"""
|
||||
size = 20
|
||||
self.token_matcher = Matcher(vocab)
|
||||
|
@ -199,7 +201,7 @@ cdef class DependencyTreeMatcher:
|
|||
return (self._callbacks[key], self._patterns[key])
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
matched_trees = []
|
||||
matched_key_trees = []
|
||||
matches = self.token_matcher(doc)
|
||||
for key in list(self._patterns.keys()):
|
||||
_patterns_list = self._patterns[key]
|
||||
|
@ -227,51 +229,36 @@ cdef class DependencyTreeMatcher:
|
|||
_nodes,_root
|
||||
)
|
||||
length = len(_nodes)
|
||||
if _root in id_to_position:
|
||||
candidates = id_to_position[_root]
|
||||
for candidate in candidates:
|
||||
isVisited = {}
|
||||
self.dfs(
|
||||
candidate,
|
||||
_root,_tree,
|
||||
id_to_position,
|
||||
doc,
|
||||
isVisited,
|
||||
_node_operator_map
|
||||
)
|
||||
# To check if the subtree pattern is completely
|
||||
# identified. This is a heuristic. This is done to
|
||||
# reduce the complexity of exponential unordered subtree
|
||||
# matching. Will give approximate matches in some cases.
|
||||
if(len(isVisited) == length):
|
||||
matched_trees.append((key,list(isVisited)))
|
||||
for i, (ent_id, nodes) in enumerate(matched_trees):
|
||||
|
||||
matched_trees = []
|
||||
self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees)
|
||||
matched_key_trees.append((key,matched_trees))
|
||||
|
||||
for i, (ent_id, nodes) in enumerate(matched_key_trees):
|
||||
on_match = self._callbacks.get(ent_id)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matches)
|
||||
return matched_trees
|
||||
return matched_key_trees
|
||||
|
||||
def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map):
|
||||
if (root in id_to_position and candidate in id_to_position[root]):
|
||||
# Color the node since it is valid
|
||||
isVisited[candidate] = True
|
||||
if root in tree:
|
||||
for root_child in tree[root]:
|
||||
if (
|
||||
candidate in _node_operator_map
|
||||
and root_child[INDEX_RELOP] in _node_operator_map[candidate]
|
||||
):
|
||||
candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]]
|
||||
for candidate_child in candidate_children:
|
||||
result = self.dfs(
|
||||
candidate_child.i,
|
||||
root_child[INDEX_HEAD],
|
||||
tree,
|
||||
id_to_position,
|
||||
doc,
|
||||
isVisited,
|
||||
_node_operator_map
|
||||
)
|
||||
def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visitedNodes,matched_trees):
|
||||
cdef bool isValid;
|
||||
if(patternLength == len(id_to_position.keys())):
|
||||
isValid = True
|
||||
for node in range(patternLength):
|
||||
if(node in tree):
|
||||
for idx, (relop,nbor) in enumerate(tree[node]):
|
||||
computed_nbors = numpy.asarray(_node_operator_map[visitedNodes[node]][relop])
|
||||
isNbor = False
|
||||
for computed_nbor in computed_nbors:
|
||||
if(computed_nbor.i == visitedNodes[nbor]):
|
||||
isNbor = True
|
||||
isValid = isValid & isNbor
|
||||
if(isValid):
|
||||
matched_trees.append(visitedNodes)
|
||||
return
|
||||
allPatternNodes = numpy.asarray(id_to_position[patternLength])
|
||||
for patternNode in allPatternNodes:
|
||||
self.recurse(tree,id_to_position,_node_operator_map,patternLength+1,visitedNodes+[patternNode],matched_trees)
|
||||
|
||||
# Given a node and an edge operator, to return the list of nodes
|
||||
# from the doc that belong to node+operator. This is used to store
|
||||
|
@ -299,8 +286,8 @@ cdef class DependencyTreeMatcher:
|
|||
switcher = {
|
||||
"<": self.dep,
|
||||
">": self.gov,
|
||||
">>": self.dep_chain,
|
||||
"<<": self.gov_chain,
|
||||
"<<": self.dep_chain,
|
||||
">>": self.gov_chain,
|
||||
".": self.imm_precede,
|
||||
"$+": self.imm_right_sib,
|
||||
"$-": self.imm_left_sib,
|
||||
|
@ -313,7 +300,7 @@ cdef class DependencyTreeMatcher:
|
|||
return _node_operator_map
|
||||
|
||||
def dep(self, doc, node):
|
||||
return list(doc[node].head)
|
||||
return [doc[node].head]
|
||||
|
||||
def gov(self,doc,node):
|
||||
return list(doc[node].children)
|
||||
|
@ -330,29 +317,29 @@ cdef class DependencyTreeMatcher:
|
|||
return []
|
||||
|
||||
def imm_right_sib(self, doc, node):
|
||||
for idx in range(list(doc[node].head.children)):
|
||||
if idx == node - 1:
|
||||
return [doc[idx]]
|
||||
for child in list(doc[node].head.children):
|
||||
if child.i == node - 1:
|
||||
return [doc[child.i]]
|
||||
return []
|
||||
|
||||
def imm_left_sib(self, doc, node):
|
||||
for idx in range(list(doc[node].head.children)):
|
||||
if idx == node + 1:
|
||||
return [doc[idx]]
|
||||
for child in list(doc[node].head.children):
|
||||
if child.i == node + 1:
|
||||
return [doc[child.i]]
|
||||
return []
|
||||
|
||||
def right_sib(self, doc, node):
|
||||
candidate_children = []
|
||||
for idx in range(list(doc[node].head.children)):
|
||||
if idx < node:
|
||||
candidate_children.append(doc[idx])
|
||||
for child in list(doc[node].head.children):
|
||||
if child.i < node:
|
||||
candidate_children.append(doc[child.i])
|
||||
return candidate_children
|
||||
|
||||
def left_sib(self, doc, node):
|
||||
candidate_children = []
|
||||
for idx in range(list(doc[node].head.children)):
|
||||
if idx > node:
|
||||
candidate_children.append(doc[idx])
|
||||
for child in list(doc[node].head.children):
|
||||
if child.i > node:
|
||||
candidate_children.append(doc[child.i])
|
||||
return candidate_children
|
||||
|
||||
def _normalize_key(self, key):
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import pytest
|
||||
import re
|
||||
from spacy.matcher import Matcher, DependencyTreeMatcher
|
||||
from spacy.matcher import Matcher, DependencyMatcher
|
||||
from spacy.tokens import Doc, Token
|
||||
from ..util import get_doc
|
||||
|
||||
|
@ -285,45 +285,44 @@ def deps():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def dependency_tree_matcher(en_vocab):
|
||||
def dependency_matcher(en_vocab):
|
||||
def is_brown_yellow(text):
|
||||
return bool(re.compile(r"brown|yellow|over").match(text))
|
||||
|
||||
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
|
||||
|
||||
pattern1 = [
|
||||
{"SPEC": {"NODE_NAME": "fox"}, "PATTERN": {"ORTH": "fox"}},
|
||||
{
|
||||
"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
|
||||
"PATTERN": {"LOWER": "quick"},
|
||||
},
|
||||
{
|
||||
"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
|
||||
"PATTERN": {IS_BROWN_YELLOW: True},
|
||||
},
|
||||
{"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},"PATTERN": {"ORTH": "quick", "DEP": "amod"}},
|
||||
{"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"}, "PATTERN": {IS_BROWN_YELLOW: True}},
|
||||
]
|
||||
|
||||
pattern2 = [
|
||||
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
|
||||
{
|
||||
"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
|
||||
"PATTERN": {"LOWER": "fox"},
|
||||
},
|
||||
{
|
||||
"SPEC": {"NODE_NAME": "over", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
|
||||
"PATTERN": {IS_BROWN_YELLOW: True},
|
||||
},
|
||||
{"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}},
|
||||
{"SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}}
|
||||
]
|
||||
matcher = DependencyTreeMatcher(en_vocab)
|
||||
|
||||
pattern3 = [
|
||||
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
|
||||
{"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}},
|
||||
{"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"}, "PATTERN": {"ORTH": "brown"}}
|
||||
]
|
||||
|
||||
matcher = DependencyMatcher(en_vocab)
|
||||
matcher.add("pattern1", None, pattern1)
|
||||
matcher.add("pattern2", None, pattern2)
|
||||
matcher.add("pattern3", None, pattern3)
|
||||
|
||||
return matcher
|
||||
|
||||
|
||||
def test_dependency_tree_matcher_compile(dependency_tree_matcher):
|
||||
assert len(dependency_tree_matcher) == 2
|
||||
def test_dependency_matcher_compile(dependency_matcher):
|
||||
assert len(dependency_matcher) == 3
|
||||
|
||||
|
||||
def test_dependency_tree_matcher(dependency_tree_matcher, text, heads, deps):
|
||||
doc = get_doc(dependency_tree_matcher.vocab, text.split(), heads=heads, deps=deps)
|
||||
matches = dependency_tree_matcher(doc)
|
||||
assert len(matches) == 2
|
||||
def test_dependency_matcher(dependency_matcher, text, heads, deps):
|
||||
doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps)
|
||||
matches = dependency_matcher(doc)
|
||||
# assert matches[0][1] == [[3, 1, 2]]
|
||||
# assert matches[1][1] == [[4, 3, 3]]
|
||||
# assert matches[2][1] == [[4, 3, 2]]
|
||||
|
|
Loading…
Reference in New Issue