Merge pull request #6018 from adrianeboyd/feature/dependency-matcher-v3

This commit is contained in:
Ines Montani 2020-09-04 20:51:50 +02:00 committed by GitHub
commit a8b5f78fc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 996 additions and 279 deletions

View File

@ -288,12 +288,12 @@ class Errors:
"Span objects, or dicts if set to manual=True.")
E097 = ("Invalid pattern: expected token pattern (list of dicts) or "
"phrase pattern (string) but got:\n{pattern}")
E098 = ("Invalid pattern specified: expected both SPEC and PATTERN.")
E099 = ("First node of pattern should be a root node. The root should "
"only contain NODE_NAME.")
E100 = ("Nodes apart from the root should contain NODE_NAME, NBOR_NAME and "
"NBOR_RELOP.")
E101 = ("NODE_NAME should be a new node and NBOR_NAME should already have "
E098 = ("Invalid pattern: expected both RIGHT_ID and RIGHT_ATTRS.")
E099 = ("Invalid pattern: the first node of pattern should be an anchor "
"node. The node should only contain RIGHT_ID and RIGHT_ATTRS.")
E100 = ("Nodes other than the anchor node should all contain LEFT_ID, "
"REL_OP and RIGHT_ID.")
E101 = ("RIGHT_ID should be a new node and LEFT_ID should already have "
"have been declared in previous edges.")
E102 = ("Can't merge non-disjoint spans. '{token}' is already part of "
"tokens to merge. If you want to find the longest non-overlapping "
@ -661,6 +661,9 @@ class Errors:
"'{chunk}'. Tokenizer exceptions are only allowed to specify "
"`ORTH` and `NORM`.")
E1006 = ("Unable to initialize {name} model with 0 labels.")
E1007 = ("Unsupported DependencyMatcher operator '{op}'.")
E1008 = ("Invalid pattern: each pattern should be a list of dicts. Check "
"that you are providing a list of patterns as `List[List[dict]]`.")
@add_codes

View File

@ -1,16 +1,16 @@
# cython: infer_types=True, profile=True
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
from libcpp cimport bool
from typing import List
import numpy
from cymem.cymem cimport Pool
from .matcher cimport Matcher
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc
from .matcher import unpickle_matcher
from ..errors import Errors
from ..tokens import Span
DELIMITER = "||"
@ -22,36 +22,52 @@ cdef class DependencyMatcher:
"""Match dependency parse tree based on pattern rules."""
cdef Pool mem
cdef readonly Vocab vocab
cdef readonly Matcher token_matcher
cdef readonly Matcher matcher
cdef public object _patterns
cdef public object _raw_patterns
cdef public object _keys_to_token
cdef public object _root
cdef public object _entities
cdef public object _callbacks
cdef public object _nodes
cdef public object _tree
cdef public object _ops
def __init__(self, vocab):
def __init__(self, vocab, *, validate=False):
"""Create the DependencyMatcher.
vocab (Vocab): The vocabulary object, which must be shared with the
documents the matcher will operate on.
validate (bool): Whether patterns should be validated, passed to
Matcher as `validate`
"""
size = 20
# TODO: make matcher work with validation
self.token_matcher = Matcher(vocab, validate=False)
self.matcher = Matcher(vocab, validate=validate)
self._keys_to_token = {}
self._patterns = {}
self._raw_patterns = {}
self._root = {}
self._nodes = {}
self._tree = {}
self._entities = {}
self._callbacks = {}
self.vocab = vocab
self.mem = Pool()
self._ops = {
"<": self.dep,
">": self.gov,
"<<": self.dep_chain,
">>": self.gov_chain,
".": self.imm_precede,
".*": self.precede,
";": self.imm_follow,
";*": self.follow,
"$+": self.imm_right_sib,
"$-": self.imm_left_sib,
"$++": self.right_sib,
"$--": self.left_sib,
}
def __reduce__(self):
data = (self.vocab, self._patterns,self._tree, self._callbacks)
data = (self.vocab, self._raw_patterns, self._callbacks)
return (unpickle_matcher, data, None, None)
def __len__(self):
@ -74,54 +90,61 @@ cdef class DependencyMatcher:
idx = 0
visited_nodes = {}
for relation in pattern:
if "PATTERN" not in relation or "SPEC" not in relation:
if not isinstance(relation, dict):
raise ValueError(Errors.E1008)
if "RIGHT_ATTRS" not in relation and "RIGHT_ID" not in relation:
raise ValueError(Errors.E098.format(key=key))
if idx == 0:
if not(
"NODE_NAME" in relation["SPEC"]
and "NBOR_RELOP" not in relation["SPEC"]
and "NBOR_NAME" not in relation["SPEC"]
"RIGHT_ID" in relation
and "REL_OP" not in relation
and "LEFT_ID" not in relation
):
raise ValueError(Errors.E099.format(key=key))
visited_nodes[relation["SPEC"]["NODE_NAME"]] = True
visited_nodes[relation["RIGHT_ID"]] = True
else:
if not(
"NODE_NAME" in relation["SPEC"]
and "NBOR_RELOP" in relation["SPEC"]
and "NBOR_NAME" in relation["SPEC"]
"RIGHT_ID" in relation
and "RIGHT_ATTRS" in relation
and "REL_OP" in relation
and "LEFT_ID" in relation
):
raise ValueError(Errors.E100.format(key=key))
if (
relation["SPEC"]["NODE_NAME"] in visited_nodes
or relation["SPEC"]["NBOR_NAME"] not in visited_nodes
relation["RIGHT_ID"] in visited_nodes
or relation["LEFT_ID"] not in visited_nodes
):
raise ValueError(Errors.E101.format(key=key))
visited_nodes[relation["SPEC"]["NODE_NAME"]] = True
visited_nodes[relation["SPEC"]["NBOR_NAME"]] = True
if relation["REL_OP"] not in self._ops:
raise ValueError(Errors.E1007.format(op=relation["REL_OP"]))
visited_nodes[relation["RIGHT_ID"]] = True
visited_nodes[relation["LEFT_ID"]] = True
idx = idx + 1
def add(self, key, patterns, *_patterns, on_match=None):
def add(self, key, patterns, *, on_match=None):
"""Add a new matcher rule to the matcher.
key (str): The match ID.
patterns (list): The patterns to add for the given key.
on_match (callable): Optional callback executed on match.
"""
if patterns is None or hasattr(patterns, "__call__"): # old API
on_match = patterns
patterns = _patterns
if on_match is not None and not hasattr(on_match, "__call__"):
raise ValueError(Errors.E171.format(arg_type=type(on_match)))
if patterns is None or not isinstance(patterns, List): # old API
raise ValueError(Errors.E948.format(arg_type=type(patterns)))
for pattern in patterns:
if len(pattern) == 0:
raise ValueError(Errors.E012.format(key=key))
self.validate_input(pattern,key)
self.validate_input(pattern, key)
key = self._normalize_key(key)
self._raw_patterns.setdefault(key, [])
self._raw_patterns[key].extend(patterns)
_patterns = []
for pattern in patterns:
token_patterns = []
for i in range(len(pattern)):
token_pattern = [pattern[i]["PATTERN"]]
token_pattern = [pattern[i]["RIGHT_ATTRS"]]
token_patterns.append(token_pattern)
# self.patterns.append(token_patterns)
_patterns.append(token_patterns)
self._patterns.setdefault(key, [])
self._callbacks[key] = on_match
@ -135,7 +158,7 @@ cdef class DependencyMatcher:
# TODO: Better ways to hash edges in pattern?
for j in range(len(_patterns[i])):
k = self._normalize_key(unicode(key) + DELIMITER + unicode(i) + DELIMITER + unicode(j))
self.token_matcher.add(k, [_patterns[i][j]])
self.matcher.add(k, [_patterns[i][j]])
_keys_to_token[k] = j
_keys_to_token_list.append(_keys_to_token)
self._keys_to_token.setdefault(key, [])
@ -144,14 +167,14 @@ cdef class DependencyMatcher:
for pattern in patterns:
nodes = {}
for i in range(len(pattern)):
nodes[pattern[i]["SPEC"]["NODE_NAME"]] = i
nodes[pattern[i]["RIGHT_ID"]] = i
_nodes_list.append(nodes)
self._nodes.setdefault(key, [])
self._nodes[key].extend(_nodes_list)
# Create an object tree to traverse later on. This data structure
# enables easy tree pattern match. Doc-Token based tree cannot be
# reused since it is memory-heavy and tightly coupled with the Doc.
self.retrieve_tree(patterns, _nodes_list,key)
self.retrieve_tree(patterns, _nodes_list, key)
def retrieve_tree(self, patterns, _nodes_list, key):
_heads_list = []
@ -161,13 +184,13 @@ cdef class DependencyMatcher:
root = -1
for j in range(len(patterns[i])):
token_pattern = patterns[i][j]
if ("NBOR_RELOP" not in token_pattern["SPEC"]):
if ("REL_OP" not in token_pattern):
heads[j] = ('root', j)
root = j
else:
heads[j] = (
token_pattern["SPEC"]["NBOR_RELOP"],
_nodes_list[i][token_pattern["SPEC"]["NBOR_NAME"]]
token_pattern["REL_OP"],
_nodes_list[i][token_pattern["LEFT_ID"]]
)
_heads_list.append(heads)
_root_list.append(root)
@ -202,11 +225,21 @@ cdef class DependencyMatcher:
RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
"""
key = self._normalize_key(key)
if key not in self._patterns:
if key not in self._raw_patterns:
return default
return (self._callbacks[key], self._patterns[key])
return (self._callbacks[key], self._raw_patterns[key])
def __call__(self, Doc doc):
def remove(self, key):
key = self._normalize_key(key)
if not key in self._patterns:
raise ValueError(Errors.E175.format(key=key))
self._patterns.pop(key)
self._raw_patterns.pop(key)
self._nodes.pop(key)
self._tree.pop(key)
self._root.pop(key)
def __call__(self, object doclike):
"""Find all token sequences matching the supplied pattern.
doclike (Doc or Span): The document to match over.
@ -214,8 +247,14 @@ cdef class DependencyMatcher:
describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `key` are both integers.
"""
if isinstance(doclike, Doc):
doc = doclike
elif isinstance(doclike, Span):
doc = doclike.as_doc()
else:
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
matched_key_trees = []
matches = self.token_matcher(doc)
matches = self.matcher(doc)
for key in list(self._patterns.keys()):
_patterns_list = self._patterns[key]
_keys_to_token_list = self._keys_to_token[key]
@ -244,26 +283,26 @@ cdef class DependencyMatcher:
length = len(_nodes)
matched_trees = []
self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees)
matched_key_trees.append((key,matched_trees))
for i, (ent_id, nodes) in enumerate(matched_key_trees):
on_match = self._callbacks.get(ent_id)
self.recurse(_tree, id_to_position, _node_operator_map, 0, [], matched_trees)
for matched_tree in matched_trees:
matched_key_trees.append((key, matched_tree))
for i, (match_id, nodes) in enumerate(matched_key_trees):
on_match = self._callbacks.get(match_id)
if on_match is not None:
on_match(self, doc, i, matched_key_trees)
return matched_key_trees
def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visited_nodes,matched_trees):
cdef bool isValid;
if(patternLength == len(id_to_position.keys())):
def recurse(self, tree, id_to_position, _node_operator_map, int patternLength, visited_nodes, matched_trees):
cdef bint isValid;
if patternLength == len(id_to_position.keys()):
isValid = True
for node in range(patternLength):
if(node in tree):
if node in tree:
for idx, (relop,nbor) in enumerate(tree[node]):
computed_nbors = numpy.asarray(_node_operator_map[visited_nodes[node]][relop])
isNbor = False
for computed_nbor in computed_nbors:
if(computed_nbor.i == visited_nodes[nbor]):
if computed_nbor.i == visited_nodes[nbor]:
isNbor = True
isValid = isValid & isNbor
if(isValid):
@ -271,14 +310,14 @@ cdef class DependencyMatcher:
return
allPatternNodes = numpy.asarray(id_to_position[patternLength])
for patternNode in allPatternNodes:
self.recurse(tree,id_to_position,_node_operator_map,patternLength+1,visited_nodes+[patternNode],matched_trees)
self.recurse(tree, id_to_position, _node_operator_map, patternLength+1, visited_nodes+[patternNode], matched_trees)
# Given a node and an edge operator, to return the list of nodes
# from the doc that belong to node+operator. This is used to store
# all the results beforehand to prevent unnecessary computation while
# pattern matching
# _node_operator_map[node][operator] = [...]
def get_node_operator_map(self,doc,tree,id_to_position,nodes,root):
def get_node_operator_map(self, doc, tree, id_to_position, nodes, root):
_node_operator_map = {}
all_node_indices = nodes.values()
all_operators = []
@ -295,24 +334,14 @@ cdef class DependencyMatcher:
_node_operator_map[node] = {}
for operator in all_operators:
_node_operator_map[node][operator] = []
# Used to invoke methods for each operator
switcher = {
"<": self.dep,
">": self.gov,
"<<": self.dep_chain,
">>": self.gov_chain,
".": self.imm_precede,
"$+": self.imm_right_sib,
"$-": self.imm_left_sib,
"$++": self.right_sib,
"$--": self.left_sib
}
for operator in all_operators:
for node in all_nodes:
_node_operator_map[node][operator] = switcher.get(operator)(doc,node)
_node_operator_map[node][operator] = self._ops.get(operator)(doc, node)
return _node_operator_map
def dep(self, doc, node):
if doc[node].head == doc[node]:
return []
return [doc[node].head]
def gov(self,doc,node):
@ -322,36 +351,51 @@ cdef class DependencyMatcher:
return list(doc[node].ancestors)
def gov_chain(self, doc, node):
return list(doc[node].subtree)
return [t for t in doc[node].subtree if t != doc[node]]
def imm_precede(self, doc, node):
if node > 0:
sent = self._get_sent(doc[node])
if node < len(doc) - 1 and doc[node + 1] in sent:
return [doc[node + 1]]
return []
def precede(self, doc, node):
sent = self._get_sent(doc[node])
return [doc[i] for i in range(node + 1, sent.end)]
def imm_follow(self, doc, node):
sent = self._get_sent(doc[node])
if node > 0 and doc[node - 1] in sent:
return [doc[node - 1]]
return []
def follow(self, doc, node):
sent = self._get_sent(doc[node])
return [doc[i] for i in range(sent.start, node)]
def imm_right_sib(self, doc, node):
for child in list(doc[node].head.children):
if child.i == node - 1:
if child.i == node + 1:
return [doc[child.i]]
return []
def imm_left_sib(self, doc, node):
for child in list(doc[node].head.children):
if child.i == node + 1:
if child.i == node - 1:
return [doc[child.i]]
return []
def right_sib(self, doc, node):
candidate_children = []
for child in list(doc[node].head.children):
if child.i < node:
if child.i > node:
candidate_children.append(doc[child.i])
return candidate_children
def left_sib(self, doc, node):
candidate_children = []
for child in list(doc[node].head.children):
if child.i > node:
if child.i < node:
candidate_children.append(doc[child.i])
return candidate_children
@ -360,3 +404,15 @@ cdef class DependencyMatcher:
return self.vocab.strings.add(key)
else:
return key
def _get_sent(self, token):
root = (list(token.ancestors) or [token])[-1]
return token.doc[root.left_edge.i:root.right_edge.i + 1]
def unpickle_matcher(vocab, patterns, callbacks):
matcher = DependencyMatcher(vocab)
for key, pattern in patterns.items():
callback = callbacks.get(key, None)
matcher.add(key, pattern, on_match=callback)
return matcher

View File

@ -0,0 +1,334 @@
import pytest
import pickle
import re
import copy
from mock import Mock
from spacy.matcher import DependencyMatcher
from ..util import get_doc
@pytest.fixture
def doc(en_vocab):
text = "The quick brown fox jumped over the lazy fox"
heads = [3, 2, 1, 1, 0, -1, 2, 1, -3]
deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "pobj", "det", "amod"]
doc = get_doc(en_vocab, text.split(), heads=heads, deps=deps)
return doc
@pytest.fixture
def patterns(en_vocab):
def is_brown_yellow(text):
return bool(re.compile(r"brown|yellow").match(text))
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
pattern1 = [
{"RIGHT_ID": "fox", "RIGHT_ATTRS": {"ORTH": "fox"}},
{
"LEFT_ID": "fox",
"REL_OP": ">",
"RIGHT_ID": "q",
"RIGHT_ATTRS": {"ORTH": "quick", "DEP": "amod"},
},
{
"LEFT_ID": "fox",
"REL_OP": ">",
"RIGHT_ID": "r",
"RIGHT_ATTRS": {IS_BROWN_YELLOW: True},
},
]
pattern2 = [
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
{
"LEFT_ID": "jumped",
"REL_OP": ">",
"RIGHT_ID": "fox1",
"RIGHT_ATTRS": {"ORTH": "fox"},
},
{
"LEFT_ID": "jumped",
"REL_OP": ".",
"RIGHT_ID": "over",
"RIGHT_ATTRS": {"ORTH": "over"},
},
]
pattern3 = [
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
{
"LEFT_ID": "jumped",
"REL_OP": ">",
"RIGHT_ID": "fox",
"RIGHT_ATTRS": {"ORTH": "fox"},
},
{
"LEFT_ID": "fox",
"REL_OP": ">>",
"RIGHT_ID": "r",
"RIGHT_ATTRS": {"ORTH": "brown"},
},
]
pattern4 = [
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
{
"LEFT_ID": "jumped",
"REL_OP": ">",
"RIGHT_ID": "fox",
"RIGHT_ATTRS": {"ORTH": "fox"},
}
]
pattern5 = [
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
{
"LEFT_ID": "jumped",
"REL_OP": ">>",
"RIGHT_ID": "fox",
"RIGHT_ATTRS": {"ORTH": "fox"},
},
]
return [pattern1, pattern2, pattern3, pattern4, pattern5]
@pytest.fixture
def dependency_matcher(en_vocab, patterns, doc):
matcher = DependencyMatcher(en_vocab)
mock = Mock()
for i in range(1, len(patterns) + 1):
if i == 1:
matcher.add("pattern1", [patterns[0]], on_match=mock)
else:
matcher.add("pattern" + str(i), [patterns[i - 1]])
return matcher
def test_dependency_matcher(dependency_matcher, doc, patterns):
assert len(dependency_matcher) == 5
assert "pattern3" in dependency_matcher
assert dependency_matcher.get("pattern3") == (None, [patterns[2]])
matches = dependency_matcher(doc)
assert len(matches) == 6
assert matches[0][1] == [3, 1, 2]
assert matches[1][1] == [4, 3, 5]
assert matches[2][1] == [4, 3, 2]
assert matches[3][1] == [4, 3]
assert matches[4][1] == [4, 3]
assert matches[5][1] == [4, 8]
span = doc[0:6]
matches = dependency_matcher(span)
assert len(matches) == 5
assert matches[0][1] == [3, 1, 2]
assert matches[1][1] == [4, 3, 5]
assert matches[2][1] == [4, 3, 2]
assert matches[3][1] == [4, 3]
assert matches[4][1] == [4, 3]
def test_dependency_matcher_pickle(en_vocab, patterns, doc):
matcher = DependencyMatcher(en_vocab)
for i in range(1, len(patterns) + 1):
matcher.add("pattern" + str(i), [patterns[i - 1]])
matches = matcher(doc)
assert matches[0][1] == [3, 1, 2]
assert matches[1][1] == [4, 3, 5]
assert matches[2][1] == [4, 3, 2]
assert matches[3][1] == [4, 3]
assert matches[4][1] == [4, 3]
assert matches[5][1] == [4, 8]
b = pickle.dumps(matcher)
matcher_r = pickle.loads(b)
assert len(matcher) == len(matcher_r)
matches = matcher_r(doc)
assert matches[0][1] == [3, 1, 2]
assert matches[1][1] == [4, 3, 5]
assert matches[2][1] == [4, 3, 2]
assert matches[3][1] == [4, 3]
assert matches[4][1] == [4, 3]
assert matches[5][1] == [4, 8]
def test_dependency_matcher_pattern_validation(en_vocab):
pattern = [
{"RIGHT_ID": "fox", "RIGHT_ATTRS": {"ORTH": "fox"}},
{
"LEFT_ID": "fox",
"REL_OP": ">",
"RIGHT_ID": "q",
"RIGHT_ATTRS": {"ORTH": "quick", "DEP": "amod"},
},
{
"LEFT_ID": "fox",
"REL_OP": ">",
"RIGHT_ID": "r",
"RIGHT_ATTRS": {"ORTH": "brown"},
},
]
matcher = DependencyMatcher(en_vocab)
# original pattern is valid
matcher.add("FOUNDED", [pattern])
# individual pattern not wrapped in a list
with pytest.raises(ValueError):
matcher.add("FOUNDED", pattern)
# no anchor node
with pytest.raises(ValueError):
matcher.add("FOUNDED", [pattern[1:]])
# required keys missing
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[0]["RIGHT_ID"]
matcher.add("FOUNDED", [pattern2])
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[1]["RIGHT_ID"]
matcher.add("FOUNDED", [pattern2])
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[1]["RIGHT_ATTRS"]
matcher.add("FOUNDED", [pattern2])
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[1]["LEFT_ID"]
matcher.add("FOUNDED", [pattern2])
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[1]["REL_OP"]
matcher.add("FOUNDED", [pattern2])
# invalid operator
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
pattern2[1]["REL_OP"] = "!!!"
matcher.add("FOUNDED", [pattern2])
# duplicate node name
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
pattern2[1]["RIGHT_ID"] = "fox"
matcher.add("FOUNDED", [pattern2])
def test_dependency_matcher_callback(en_vocab, doc):
pattern = [
{"RIGHT_ID": "quick", "RIGHT_ATTRS": {"ORTH": "quick"}},
]
matcher = DependencyMatcher(en_vocab)
mock = Mock()
matcher.add("pattern", [pattern], on_match=mock)
matches = matcher(doc)
mock.assert_called_once_with(matcher, doc, 0, matches)
# check that matches with and without callback are the same (#4590)
matcher2 = DependencyMatcher(en_vocab)
matcher2.add("pattern", [pattern])
matches2 = matcher2(doc)
assert matches == matches2
@pytest.mark.parametrize(
"op,num_matches", [(".", 8), (".*", 20), (";", 8), (";*", 20),]
)
def test_dependency_matcher_precedence_ops(en_vocab, op, num_matches):
# two sentences to test that all matches are within the same sentence
doc = get_doc(
en_vocab,
words=["a", "b", "c", "d", "e"] * 2,
heads=[0, -1, -2, -3, -4] * 2,
deps=["dep"] * 10,
)
match_count = 0
for text in ["a", "b", "c", "d", "e"]:
pattern = [
{"RIGHT_ID": "1", "RIGHT_ATTRS": {"ORTH": text}},
{"LEFT_ID": "1", "REL_OP": op, "RIGHT_ID": "2", "RIGHT_ATTRS": {},},
]
matcher = DependencyMatcher(en_vocab)
matcher.add("A", [pattern])
matches = matcher(doc)
match_count += len(matches)
for match in matches:
match_id, token_ids = match
# token_ids[0] op token_ids[1]
if op == ".":
assert token_ids[0] == token_ids[1] - 1
elif op == ";":
assert token_ids[0] == token_ids[1] + 1
elif op == ".*":
assert token_ids[0] < token_ids[1]
elif op == ";*":
assert token_ids[0] > token_ids[1]
# all tokens are within the same sentence
assert doc[token_ids[0]].sent == doc[token_ids[1]].sent
assert match_count == num_matches
@pytest.mark.parametrize(
"left,right,op,num_matches",
[
("fox", "jumped", "<", 1),
("the", "lazy", "<", 0),
("jumped", "jumped", "<", 0),
("fox", "jumped", ">", 0),
("fox", "lazy", ">", 1),
("lazy", "lazy", ">", 0),
("fox", "jumped", "<<", 2),
("jumped", "fox", "<<", 0),
("the", "fox", "<<", 2),
("fox", "jumped", ">>", 0),
("over", "the", ">>", 1),
("fox", "the", ">>", 2),
("fox", "jumped", ".", 1),
("lazy", "fox", ".", 1),
("the", "fox", ".", 0),
("the", "the", ".", 0),
("fox", "jumped", ";", 0),
("lazy", "fox", ";", 0),
("the", "fox", ";", 0),
("the", "the", ";", 0),
("quick", "fox", ".*", 2),
("the", "fox", ".*", 3),
("the", "the", ".*", 1),
("fox", "jumped", ";*", 1),
("quick", "fox", ";*", 0),
("the", "fox", ";*", 1),
("the", "the", ";*", 1),
("quick", "brown", "$+", 1),
("brown", "quick", "$+", 0),
("brown", "brown", "$+", 0),
("quick", "brown", "$-", 0),
("brown", "quick", "$-", 1),
("brown", "brown", "$-", 0),
("the", "brown", "$++", 1),
("brown", "the", "$++", 0),
("brown", "brown", "$++", 0),
("the", "brown", "$--", 0),
("brown", "the", "$--", 1),
("brown", "brown", "$--", 0),
],
)
def test_dependency_matcher_ops(en_vocab, doc, left, right, op, num_matches):
right_id = right
if left == right:
right_id = right + "2"
pattern = [
{"RIGHT_ID": left, "RIGHT_ATTRS": {"LOWER": left}},
{
"LEFT_ID": left,
"REL_OP": op,
"RIGHT_ID": right_id,
"RIGHT_ATTRS": {"LOWER": right},
},
]
matcher = DependencyMatcher(en_vocab)
matcher.add("pattern", [pattern])
matches = matcher(doc)
assert len(matches) == num_matches

View File

@ -1,7 +1,6 @@
import pytest
import re
from mock import Mock
from spacy.matcher import Matcher, DependencyMatcher
from spacy.matcher import Matcher
from spacy.tokens import Doc, Token, Span
from ..doc.test_underscore import clean_underscore # noqa: F401
@ -292,84 +291,6 @@ def test_matcher_extension_set_membership(en_vocab):
assert len(matches) == 0
@pytest.fixture
def text():
return "The quick brown fox jumped over the lazy fox"
@pytest.fixture
def heads():
return [3, 2, 1, 1, 0, -1, 2, 1, -3]
@pytest.fixture
def deps():
return ["det", "amod", "amod", "nsubj", "prep", "pobj", "det", "amod"]
@pytest.fixture
def dependency_matcher(en_vocab):
def is_brown_yellow(text):
return bool(re.compile(r"brown|yellow|over").match(text))
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
pattern1 = [
{"SPEC": {"NODE_NAME": "fox"}, "PATTERN": {"ORTH": "fox"}},
{
"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
"PATTERN": {"ORTH": "quick", "DEP": "amod"},
},
{
"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
"PATTERN": {IS_BROWN_YELLOW: True},
},
]
pattern2 = [
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
{
"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
"PATTERN": {"ORTH": "fox"},
},
{
"SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"},
"PATTERN": {"ORTH": "fox"},
},
]
pattern3 = [
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
{
"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
"PATTERN": {"ORTH": "fox"},
},
{
"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"},
"PATTERN": {"ORTH": "brown"},
},
]
matcher = DependencyMatcher(en_vocab)
matcher.add("pattern1", [pattern1])
matcher.add("pattern2", [pattern2])
matcher.add("pattern3", [pattern3])
return matcher
def test_dependency_matcher_compile(dependency_matcher):
assert len(dependency_matcher) == 3
# def test_dependency_matcher(dependency_matcher, text, heads, deps):
# doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps)
# matches = dependency_matcher(doc)
# assert matches[0][1] == [[3, 1, 2]]
# assert matches[1][1] == [[4, 3, 3]]
# assert matches[2][1] == [[4, 3, 2]]
def test_matcher_basic_check(en_vocab):
matcher = Matcher(en_vocab)
# Potential mistake: pass in pattern instead of list of patterns

View File

@ -38,32 +38,6 @@ def test_gold_misaligned(en_tokenizer, text, words):
Example.from_dict(doc, {"words": words})
def test_issue4590(en_vocab):
"""Test that matches param in on_match method are the same as matches run with no on_match method"""
pattern = [
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
{
"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
"PATTERN": {"ORTH": "fox"},
},
{
"SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"},
"PATTERN": {"ORTH": "fox"},
},
]
on_match = Mock()
matcher = DependencyMatcher(en_vocab)
matcher.add("pattern", on_match, pattern)
text = "The quick brown fox jumped over the lazy fox"
heads = [3, 2, 1, 1, 0, -1, 2, 1, -3]
deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "det", "amod", "pobj"]
doc = get_doc(en_vocab, text.split(), heads=heads, deps=deps)
matches = matcher(doc)
on_match_args = on_match.call_args
assert on_match_args[0][3] == matches
def test_issue4651_with_phrase_matcher_attr():
"""Test that the EntityRuler PhraseMatcher is deserialized correctly using
the method from_disk when the EntityRuler argument phrase_matcher_attr is

View File

@ -1,65 +1,91 @@
---
title: DependencyMatcher
teaser: Match sequences of tokens, based on the dependency parse
teaser: Match subtrees within a dependency parse
tag: class
new: 3
source: spacy/matcher/dependencymatcher.pyx
---
The `DependencyMatcher` follows the same API as the [`Matcher`](/api/matcher)
and [`PhraseMatcher`](/api/phrasematcher) and lets you match on dependency trees
using the
[Semgrex syntax](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html).
It requires a trained [`DependencyParser`](/api/parser) or other component that
sets the `Token.dep` attribute.
using
[Semgrex operators](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html).
It requires a pretrained [`DependencyParser`](/api/parser) or other component
that sets the `Token.dep` and `Token.head` attributes. See the
[usage guide](/usage/rule-based-matching#dependencymatcher) for examples.
## Pattern format {#patterns}
> ```json
> ```python
> ### Example
> # pattern: "[subject] ... initially founded"
> [
> # anchor token: founded
> {
> "SPEC": {"NODE_NAME": "founded"},
> "PATTERN": {"ORTH": "founded"}
> "RIGHT_ID": "founded",
> "RIGHT_ATTRS": {"ORTH": "founded"}
> },
> # founded -> subject
> {
> "SPEC": {
> "NODE_NAME": "founder",
> "NBOR_RELOP": ">",
> "NBOR_NAME": "founded"
> },
> "PATTERN": {"DEP": "nsubj"}
> "LEFT_ID": "founded",
> "REL_OP": ">",
> "RIGHT_ID": "subject",
> "RIGHT_ATTRS": {"DEP": "nsubj"}
> },
> # "founded" follows "initially"
> {
> "SPEC": {
> "NODE_NAME": "object",
> "NBOR_RELOP": ">",
> "NBOR_NAME": "founded"
> },
> "PATTERN": {"DEP": "dobj"}
> "LEFT_ID": "founded",
> "REL_OP": ";",
> "RIGHT_ID": "initially",
> "RIGHT_ATTRS": {"ORTH": "initially"}
> }
> ]
> ```
A pattern added to the `DependencyMatcher` consists of a list of dictionaries,
with each dictionary describing a node to match. Each pattern should have the
following top-level keys:
with each dictionary describing a token to match. Except for the first
dictionary, which defines an anchor token using only `RIGHT_ID` and
`RIGHT_ATTRS`, each pattern should have the following keys:
| Name | Description |
| --------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| `PATTERN` | The token attributes to match in the same format as patterns provided to the regular token-based [`Matcher`](/api/matcher). ~~Dict[str, Any]~~ |
| `SPEC` | The relationships of the nodes in the subtree that should be matched. ~~Dict[str, str]~~ |
| Name | Description |
| ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `LEFT_ID` | The name of the left-hand node in the relation, which has been defined in an earlier node. ~~str~~ |
| `REL_OP` | An operator that describes how the two nodes are related. ~~str~~ |
| `RIGHT_ID` | A unique name for the right-hand node in the relation. ~~str~~ |
| `RIGHT_ATTRS` | The token attributes to match for the right-hand node in the same format as patterns provided to the regular token-based [`Matcher`](/api/matcher). ~~Dict[str, Any]~~ |
The `SPEC` includes the following fields:
<Infobox title="Designing dependency matcher patterns" emoji="📖">
| Name | Description |
| ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `NODE_NAME` | A unique name for this node to refer to it in other specs. ~~str~~ |
| `NBOR_RELOP` | A [Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html) operator that describes how the two nodes are related. ~~str~~ |
| `NBOR_NAME` | The unique name of the node that this node is connected to. ~~str~~ |
For examples of how to construct dependency matcher patterns for different types
of relations, see the usage guide on
[dependency matching](/usage/rule-based-matching#dependencymatcher).
</Infobox>
### Operators
The following operators are supported by the `DependencyMatcher`, most of which
come directly from
[Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html):
| Symbol | Description |
| --------- | -------------------------------------------------------------------------------------------------------------------- |
| `A < B` | `A` is the immediate dependent of `B`. |
| `A > B` | `A` is the immediate head of `B`. |
| `A << B` | `A` is the dependent in a chain to `B` following dep &rarr; head paths. |
| `A >> B` | `A` is the head in a chain to `B` following head &rarr; dep paths. |
| `A . B` | `A` immediately precedes `B`, i.e. `A.i == B.i - 1`, and both are within the same dependency tree. |
| `A .* B` | `A` precedes `B`, i.e. `A.i < B.i`, and both are within the same dependency tree _(not in Semgrex)_. |
| `A ; B` | `A` immediately follows `B`, i.e. `A.i == B.i + 1`, and both are within the same dependency tree _(not in Semgrex)_. |
| `A ;* B` | `A` follows `B`, i.e. `A.i > B.i`, and both are within the same dependency tree _(not in Semgrex)_. |
| `A $+ B` | `B` is a right immediate sibling of `A`, i.e. `A` and `B` have the same parent and `A.i == B.i - 1`. |
| `A $- B` | `B` is a left immediate sibling of `A`, i.e. `A` and `B` have the same parent and `A.i == B.i + 1`. |
| `A $++ B` | `B` is a right sibling of `A`, i.e. `A` and `B` have the same parent and `A.i < B.i`. |
| `A $-- B` | `B` is a left sibling of `A`, i.e. `A` and `B` have the same parent and `A.i > B.i`. |
## DependencyMatcher.\_\_init\_\_ {#init tag="method"}
Create a rule-based `DependencyMatcher`.
Create a `DependencyMatcher`.
> #### Example
>
@ -68,13 +94,15 @@ Create a rule-based `DependencyMatcher`.
> matcher = DependencyMatcher(nlp.vocab)
> ```
| Name | Description |
| ------- | ----------------------------------------------------------------------------------------------------- |
| `vocab` | The vocabulary object, which must be shared with the documents the matcher will operate on. ~~Vocab~~ |
| Name | Description |
| -------------- | ----------------------------------------------------------------------------------------------------- |
| `vocab` | The vocabulary object, which must be shared with the documents the matcher will operate on. ~~Vocab~~ |
| _keyword-only_ | |
| `validate` | Validate all patterns added to this matcher. ~~bool~~ |
## DependencyMatcher.\_\call\_\_ {#call tag="method"}
Find all token sequences matching the supplied patterns on the `Doc` or `Span`.
Find all tokens matching the supplied patterns on the `Doc` or `Span`.
> #### Example
>
@ -82,36 +110,32 @@ Find all token sequences matching the supplied patterns on the `Doc` or `Span`.
> from spacy.matcher import DependencyMatcher
>
> matcher = DependencyMatcher(nlp.vocab)
> pattern = [
> {"SPEC": {"NODE_NAME": "founded"}, "PATTERN": {"ORTH": "founded"}},
> {"SPEC": {"NODE_NAME": "founder", "NBOR_RELOP": ">", "NBOR_NAME": "founded"}, "PATTERN": {"DEP": "nsubj"}},
> ]
> matcher.add("Founder", [pattern])
> pattern = [{"RIGHT_ID": "founded_id",
> "RIGHT_ATTRS": {"ORTH": "founded"}}]
> matcher.add("FOUNDED", [pattern])
> doc = nlp("Bill Gates founded Microsoft.")
> matches = matcher(doc)
> ```
| Name | Description |
| ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `doclike` | The `Doc` or `Span` to match over. ~~Union[Doc, Span]~~ |
| **RETURNS** | A list of `(match_id, start, end)` tuples, describing the matches. A match tuple describes a span `doc[start:end`]. The `match_id` is the ID of the added match pattern. ~~List[Tuple[int, int, int]]~~ |
| Name | Description |
| ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `doclike` | The `Doc` or `Span` to match over. ~~Union[Doc, Span]~~ |
| **RETURNS** | A list of `(match_id, token_ids)` tuples, describing the matches. The `match_id` is the ID of the match pattern and `token_ids` is a list of token indices matched by the pattern, where the position of each token in the list corresponds to the position of the node specification in the pattern. ~~List[Tuple[int, List[int]]]~~ |
## DependencyMatcher.\_\_len\_\_ {#len tag="method"}
Get the number of rules (edges) added to the dependency matcher. Note that this
only returns the number of rules (identical with the number of IDs), not the
number of individual patterns.
Get the number of rules added to the dependency matcher. Note that this only
returns the number of rules (identical with the number of IDs), not the number
of individual patterns.
> #### Example
>
> ```python
> matcher = DependencyMatcher(nlp.vocab)
> assert len(matcher) == 0
> pattern = [
> {"SPEC": {"NODE_NAME": "founded"}, "PATTERN": {"ORTH": "founded"}},
> {"SPEC": {"NODE_NAME": "START_ENTITY", "NBOR_RELOP": ">", "NBOR_NAME": "founded"}, "PATTERN": {"DEP": "nsubj"}},
> ]
> matcher.add("Rule", [pattern])
> pattern = [{"RIGHT_ID": "founded_id",
> "RIGHT_ATTRS": {"ORTH": "founded"}}]
> matcher.add("FOUNDED", [pattern])
> assert len(matcher) == 1
> ```
@ -126,10 +150,10 @@ Check whether the matcher contains rules for a match ID.
> #### Example
>
> ```python
> matcher = Matcher(nlp.vocab)
> assert "Rule" not in matcher
> matcher.add("Rule", [pattern])
> assert "Rule" in matcher
> matcher = DependencyMatcher(nlp.vocab)
> assert "FOUNDED" not in matcher
> matcher.add("FOUNDED", [pattern])
> assert "FOUNDED" in matcher
> ```
| Name | Description |
@ -152,33 +176,15 @@ will be overwritten.
> print('Matched!', matches)
>
> matcher = DependencyMatcher(nlp.vocab)
> matcher.add("TEST_PATTERNS", patterns)
> matcher.add("FOUNDED", patterns, on_match=on_match)
> ```
| Name | Description |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `match_id` | An ID for the thing you're matching. ~~str~~ |
| `patterns` | list | Match pattern. A pattern consists of a list of dicts, where each dict describes a `"PATTERN"` and `"SPEC"`. ~~List[List[Dict[str, dict]]]~~ |
| _keyword-only_ | | |
| `on_match` | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. ~~Optional[Callable[[Matcher, Doc, int, List[tuple], Any]]~~ |
## DependencyMatcher.remove {#remove tag="method"}
Remove a rule from the matcher. A `KeyError` is raised if the match ID does not
exist.
> #### Example
>
> ```python
> matcher.add("Rule", [pattern]])
> assert "Rule" in matcher
> matcher.remove("Rule")
> assert "Rule" not in matcher
> ```
| Name | Description |
| ----- | --------------------------------- |
| `key` | The ID of the match rule. ~~str~~ |
| Name | Description |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `match_id` | An ID for the patterns. ~~str~~ |
| `patterns` | A list of match patterns. A pattern consists of a list of dicts, where each dict describes a token in the tree. ~~List[List[Dict[str, Union[str, Dict]]]]~~ |
| _keyword-only_ | | |
| `on_match` | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. ~~Optional[Callable[[DependencyMatcher, Doc, int, List[Tuple], Any]]~~ |
## DependencyMatcher.get {#get tag="method"}
@ -188,11 +194,29 @@ Retrieve the pattern stored for a key. Returns the rule as an
> #### Example
>
> ```python
> matcher.add("Rule", [pattern], on_match=on_match)
> on_match, patterns = matcher.get("Rule")
> matcher.add("FOUNDED", patterns, on_match=on_match)
> on_match, patterns = matcher.get("FOUNDED")
> ```
| Name | Description |
| ----------- | --------------------------------------------------------------------------------------------- |
| `key` | The ID of the match rule. ~~str~~ |
| **RETURNS** | The rule, as an `(on_match, patterns)` tuple. ~~Tuple[Optional[Callable], List[List[dict]]]~~ |
| Name | Description |
| ----------- | ----------------------------------------------------------------------------------------------------------- |
| `key` | The ID of the match rule. ~~str~~ |
| **RETURNS** | The rule, as an `(on_match, patterns)` tuple. ~~Tuple[Optional[Callable], List[List[Union[Dict, Tuple]]]]~~ |
## DependencyMatcher.remove {#remove tag="method"}
Remove a rule from the dependency matcher. A `KeyError` is raised if the match
ID does not exist.
> #### Example
>
> ```python
> matcher.add("FOUNDED", patterns)
> assert "FOUNDED" in matcher
> matcher.remove("FOUNDED")
> assert "FOUNDED" not in matcher
> ```
| Name | Description |
| ----- | --------------------------------- |
| `key` | The ID of the match rule. ~~str~~ |

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 25 KiB

View File

@ -0,0 +1,58 @@
<svg xmlns="http://www.w3.org/2000/svg" xlink="http://www.w3.org/1999/xlink" xml:lang="en" id="c3124cc3e661444cb9d4175a5b7c09d1-0" class="displacy" width="925" height="399.5" direction="ltr" style="max-width: none; height: 399.5px; color: #000000; background: #ffffff; font-family: Arial; direction: ltr">
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="309.5">
<tspan class="displacy-word" fill="currentColor" x="50">Smith</tspan>
<tspan class="displacy-tag" dy="2em" fill="currentColor" x="50"></tspan>
</text>
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="309.5">
<tspan class="displacy-word" fill="currentColor" x="225">founded</tspan>
<tspan class="displacy-tag" dy="2em" fill="currentColor" x="225"></tspan>
</text>
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="309.5">
<tspan class="displacy-word" fill="currentColor" x="400">a</tspan>
<tspan class="displacy-tag" dy="2em" fill="currentColor" x="400"></tspan>
</text>
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="309.5">
<tspan class="displacy-word" fill="currentColor" x="575">healthcare</tspan>
<tspan class="displacy-tag" dy="2em" fill="currentColor" x="575"></tspan>
</text>
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="309.5">
<tspan class="displacy-word" fill="currentColor" x="750">company</tspan>
<tspan class="displacy-tag" dy="2em" fill="currentColor" x="750"></tspan>
</text>
<g class="displacy-arrow">
<path class="displacy-arc" id="arrow-c3124cc3e661444cb9d4175a5b7c09d1-0-0" stroke-width="2px" d="M70,264.5 C70,177.0 215.0,177.0 215.0,264.5" fill="none" stroke="currentColor"></path>
<text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
<textPath xlink:href="#arrow-c3124cc3e661444cb9d4175a5b7c09d1-0-0" class="displacy-label" startOffset="50%" side="left" fill="currentColor" text-anchor="middle">nsubj</textPath>
</text>
<path class="displacy-arrowhead" d="M70,266.5 L62,254.5 78,254.5" fill="currentColor"></path>
</g>
<g class="displacy-arrow">
<path class="displacy-arc" id="arrow-c3124cc3e661444cb9d4175a5b7c09d1-0-1" stroke-width="2px" d="M420,264.5 C420,89.5 745.0,89.5 745.0,264.5" fill="none" stroke="currentColor"></path>
<text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
<textPath xlink:href="#arrow-c3124cc3e661444cb9d4175a5b7c09d1-0-1" class="displacy-label" startOffset="50%" side="left" fill="currentColor" text-anchor="middle">det</textPath>
</text>
<path class="displacy-arrowhead" d="M420,266.5 L412,254.5 428,254.5" fill="currentColor"></path>
</g>
<g class="displacy-arrow">
<path class="displacy-arc" id="arrow-c3124cc3e661444cb9d4175a5b7c09d1-0-2" stroke-width="2px" d="M595,264.5 C595,177.0 740.0,177.0 740.0,264.5" fill="none" stroke="currentColor"></path>
<text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
<textPath xlink:href="#arrow-c3124cc3e661444cb9d4175a5b7c09d1-0-2" class="displacy-label" startOffset="50%" side="left" fill="currentColor" text-anchor="middle">compound</textPath>
</text>
<path class="displacy-arrowhead" d="M595,266.5 L587,254.5 603,254.5" fill="currentColor"></path>
</g>
<g class="displacy-arrow">
<path class="displacy-arc" id="arrow-c3124cc3e661444cb9d4175a5b7c09d1-0-3" stroke-width="2px" d="M245,264.5 C245,2.0 750.0,2.0 750.0,264.5" fill="none" stroke="currentColor"></path>
<text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
<textPath xlink:href="#arrow-c3124cc3e661444cb9d4175a5b7c09d1-0-3" class="displacy-label" startOffset="50%" side="left" fill="currentColor" text-anchor="middle">dobj</textPath>
</text>
<path class="displacy-arrowhead" d="M750.0,266.5 L758.0,254.5 742.0,254.5" fill="currentColor"></path>
</g>
</svg>

After

Width:  |  Height:  |  Size: 3.8 KiB

View File

@ -4,6 +4,7 @@ teaser: Find phrases and tokens, and match entities
menu:
- ['Token Matcher', 'matcher']
- ['Phrase Matcher', 'phrasematcher']
- ['Dependency Matcher', 'dependencymatcher']
- ['Entity Ruler', 'entityruler']
- ['Models & Rules', 'models-rules']
---
@ -939,10 +940,10 @@ object patterns as efficiently as possible and without running any of the other
pipeline components. If the token attribute you want to match on are set by a
pipeline component, **make sure that the pipeline component runs** when you
create the pattern. For example, to match on `POS` or `LEMMA`, the pattern `Doc`
objects need to have part-of-speech tags set by the `tagger`. You can either
call the `nlp` object on your pattern texts instead of `nlp.make_doc`, or use
[`nlp.select_pipes`](/api/language#select_pipes) to disable components
selectively.
objects need to have part-of-speech tags set by the `tagger` or `morphologizer`.
You can either call the `nlp` object on your pattern texts instead of
`nlp.make_doc`, or use [`nlp.select_pipes`](/api/language#select_pipes) to
disable components selectively.
</Infobox>
@ -973,10 +974,287 @@ to match phrases with the same sequence of punctuation and non-punctuation
tokens as the pattern. But this can easily get confusing and doesn't have much
of an advantage over writing one or two token patterns.
## Dependency Matcher {#dependencymatcher new="3" model="parser"}
The [`DependencyMatcher`](/api/dependencymatcher) lets you match patterns within
the dependency parse using
[Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html)
operators. It requires a model containing a parser such as the
[`DependencyParser`](/api/dependencyparser). Instead of defining a list of
adjacent tokens as in `Matcher` patterns, the `DependencyMatcher` patterns match
tokens in the dependency parse and specify the relations between them.
> ```python
> ### Example
> from spacy.matcher import DependencyMatcher
>
> # "[subject] ... initially founded"
> pattern = [
> # anchor token: founded
> {
> "RIGHT_ID": "founded",
> "RIGHT_ATTRS": {"ORTH": "founded"}
> },
> # founded -> subject
> {
> "LEFT_ID": "founded",
> "REL_OP": ">",
> "RIGHT_ID": "subject",
> "RIGHT_ATTRS": {"DEP": "nsubj"}
> },
> # "founded" follows "initially"
> {
> "LEFT_ID": "founded",
> "REL_OP": ";",
> "RIGHT_ID": "initially",
> "RIGHT_ATTRS": {"ORTH": "initially"}
> }
> ]
>
> matcher = DependencyMatcher(nlp.vocab)
> matcher.add("FOUNDED", [pattern])
> matches = matcher(doc)
> ```
A pattern added to the dependency matcher consists of a **list of
dictionaries**, with each dictionary describing a **token to match** and its
**relation to an existing token** in the pattern. Except for the first
dictionary, which defines an anchor token using only `RIGHT_ID` and
`RIGHT_ATTRS`, each pattern should have the following keys:
| Name | Description |
| ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `LEFT_ID` | The name of the left-hand node in the relation, which has been defined in an earlier node. ~~str~~ |
| `REL_OP` | An operator that describes how the two nodes are related. ~~str~~ |
| `RIGHT_ID` | A unique name for the right-hand node in the relation. ~~str~~ |
| `RIGHT_ATTRS` | The token attributes to match for the right-hand node in the same format as patterns provided to the regular token-based [`Matcher`](/api/matcher). ~~Dict[str, Any]~~ |
Each additional token added to the pattern is linked to an existing token
`LEFT_ID` by the relation `REL_OP`. The new token is given the name `RIGHT_ID`
and described by the attributes `RIGHT_ATTRS`.
<Infobox title="Important note" variant="warning">
Because the unique token **names** in `LEFT_ID` and `RIGHT_ID` are used to
identify tokens, the order of the dicts in the patterns is important: a token
name needs to be defined as `RIGHT_ID` in one dict in the pattern **before** it
can be used as `LEFT_ID` in another dict.
</Infobox>
### Dependency matcher operators {#dependencymatcher-operators}
The following operators are supported by the `DependencyMatcher`, most of which
come directly from
[Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html):
| Symbol | Description |
| --------- | -------------------------------------------------------------------------------------------------------------------- |
| `A < B` | `A` is the immediate dependent of `B`. |
| `A > B` | `A` is the immediate head of `B`. |
| `A << B` | `A` is the dependent in a chain to `B` following dep &rarr; head paths. |
| `A >> B` | `A` is the head in a chain to `B` following head &rarr; dep paths. |
| `A . B` | `A` immediately precedes `B`, i.e. `A.i == B.i - 1`, and both are within the same dependency tree. |
| `A .* B` | `A` precedes `B`, i.e. `A.i < B.i`, and both are within the same dependency tree _(not in Semgrex)_. |
| `A ; B` | `A` immediately follows `B`, i.e. `A.i == B.i + 1`, and both are within the same dependency tree _(not in Semgrex)_. |
| `A ;* B` | `A` follows `B`, i.e. `A.i > B.i`, and both are within the same dependency tree _(not in Semgrex)_. |
| `A $+ B` | `B` is a right immediate sibling of `A`, i.e. `A` and `B` have the same parent and `A.i == B.i - 1`. |
| `A $- B` | `B` is a left immediate sibling of `A`, i.e. `A` and `B` have the same parent and `A.i == B.i + 1`. |
| `A $++ B` | `B` is a right sibling of `A`, i.e. `A` and `B` have the same parent and `A.i < B.i`. |
| `A $-- B` | `B` is a left sibling of `A`, i.e. `A` and `B` have the same parent and `A.i > B.i`. |
### Designing dependency matcher patterns {#dependencymatcher-patterns}
Let's say we want to find sentences describing who founded what kind of company:
- _Smith founded a healthcare company in 2005._
- _Williams initially founded an insurance company in 1987._
- _Lee, an experienced CEO, has founded two AI startups._
The dependency parse for "Smith founded a healthcare company" shows types of
relations and tokens we want to match:
> #### Visualizing the parse
>
> The [`displacy` visualizer](/usage/visualizer) lets you render `Doc` objects
> and their dependency parse and part-of-speech tags:
>
> ```python
> import spacy
> from spacy import displacy
>
> nlp = spacy.load("en_core_web_sm")
> doc = nlp("Smith founded a healthcare company")
> displacy.serve(doc)
> ```
import DisplaCyDepFoundedHtml from 'images/displacy-dep-founded.html'
<Iframe title="displaCy visualization of dependencies" html={DisplaCyDepFoundedHtml} height={450} />
The relations we're interested in are:
- the founder is the **subject** (`nsubj`) of the token with the text `founded`
- the company is the **object** (`dobj`) of `founded`
- the kind of company may be an **adjective** (`amod`, not shown above) or a
**compound** (`compound`)
The first step is to pick an **anchor token** for the pattern. Since it's the
root of the dependency parse, `founded` is a good choice here. It is often
easier to construct patterns when all dependency relation operators point from
the head to the children. In this example, we'll only use `>`, which connects a
head to an immediate dependent as `head > child`.
The simplest dependency matcher pattern will identify and name a single token in
the tree:
```python
### {executable="true"}
import spacy
from spacy.matcher import DependencyMatcher
nlp = spacy.load("en_core_web_sm")
matcher = DependencyMatcher(nlp.vocab)
pattern = [
{
"RIGHT_ID": "anchor_founded", # unique name
"RIGHT_ATTRS": {"ORTH": "founded"} # token pattern for "founded"
}
]
matcher.add("FOUNDED", [pattern])
doc = nlp("Smith founded two companies.")
matches = matcher(doc)
print(matches) # [(4851363122962674176, [1])]
```
Now that we have a named anchor token (`anchor_founded`), we can add the founder
as the immediate dependent (`>`) of `founded` with the dependency label `nsubj`:
```python
### Step 1 {highlight="8,10"}
pattern = [
{
"RIGHT_ID": "anchor_founded",
"RIGHT_ATTRS": {"ORTH": "founded"}
},
{
"LEFT_ID": "anchor_founded",
"REL_OP": ">",
"RIGHT_ID": "subject",
"RIGHT_ATTRS": {"DEP": "nsubj"},
}
# ...
]
```
The direct object (`dobj`) is added in the same way:
```python
### Step 2 {highlight=""}
pattern = [
#...
{
"LEFT_ID": "anchor_founded",
"REL_OP": ">",
"RIGHT_ID": "founded_object",
"RIGHT_ATTRS": {"DEP": "dobj"},
}
# ...
]
```
When the subject and object tokens are added, they are required to have names
under the key `RIGHT_ID`, which are allowed to be any unique string, e.g.
`founded_subject`. These names can then be used as `LEFT_ID` to **link new
tokens into the pattern**. For the final part of our pattern, we'll specify that
the token `founded_object` should have a modifier with the dependency relation
`amod` or `compound`:
```python
### Step 3 {highlight="7"}
pattern = [
# ...
{
"LEFT_ID": "founded_object",
"REL_OP": ">",
"RIGHT_ID": "founded_object_modifier",
"RIGHT_ATTRS": {"DEP": {"IN": ["amod", "compound"]}},
}
]
```
You can picture the process of creating a dependency matcher pattern as defining
an anchor token on the left and building up the pattern by linking tokens
one-by-one on the right using relation operators. To create a valid pattern,
each new token needs to be linked to an existing token on its left. As for
`founded` in this example, a token may be linked to more than one token on its
right:
![Dependency matcher pattern](../images/dep-match-diagram.svg)
The full pattern comes together as shown in the example below:
```python
### {executable="true"}
import spacy
from spacy.matcher import DependencyMatcher
nlp = spacy.load("en_core_web_sm")
matcher = DependencyMatcher(nlp.vocab)
pattern = [
{
"RIGHT_ID": "anchor_founded",
"RIGHT_ATTRS": {"ORTH": "founded"}
},
{
"LEFT_ID": "anchor_founded",
"REL_OP": ">",
"RIGHT_ID": "subject",
"RIGHT_ATTRS": {"DEP": "nsubj"},
},
{
"LEFT_ID": "anchor_founded",
"REL_OP": ">",
"RIGHT_ID": "founded_object",
"RIGHT_ATTRS": {"DEP": "dobj"},
},
{
"LEFT_ID": "founded_object",
"REL_OP": ">",
"RIGHT_ID": "founded_object_modifier",
"RIGHT_ATTRS": {"DEP": {"IN": ["amod", "compound"]}},
}
]
matcher.add("FOUNDED", [pattern])
doc = nlp("Lee, an experienced CEO, has founded two AI startups.")
matches = matcher(doc)
print(matches) # [(4851363122962674176, [6, 0, 10, 9])]
# Each token_id corresponds to one pattern dict
match_id, token_ids = matches[0]
for i in range(len(token_ids)):
print(pattern[i]["RIGHT_ID"] + ":", doc[token_ids[i]].text)
```
<Infobox title="Important note on speed" variant="warning">
The dependency matcher may be slow when token patterns can potentially match
many tokens in the sentence or when relation operators allow longer paths in the
dependency parse, e.g. `<<`, `>>`, `.*` and `;*`.
To improve the matcher speed, try to make your token patterns and operators as
specific as possible. For example, use `>` instead of `>>` if possible and use
token patterns that include dependency labels and other token attributes instead
of patterns such as `{}` that match any token in the sentence.
</Infobox>
## Rule-based entity recognition {#entityruler new="2.1"}
The [`EntityRuler`](/api/entityruler) is an exciting new component that lets you
add named entities based on pattern dictionaries, and makes it easy to combine
The [`EntityRuler`](/api/entityruler) is a component that lets you add named
entities based on pattern dictionaries, which makes it easy to combine
rule-based and statistical named entity recognition for even more powerful
pipelines.

View File

@ -26,6 +26,7 @@ menu:
- [End-to-end project workflows](#features-projects)
- [New built-in components](#features-pipeline-components)
- [New custom component API](#features-components)
- [Dependency matching](#features-dep-matcher)
- [Python type hints](#features-types)
- [New methods & attributes](#new-methods)
- [New & updated documentation](#new-docs)
@ -201,6 +202,34 @@ aren't set.
</Infobox>
### Dependency matching {#features-dep-matcher}
<!-- TODO: improve summary -->
> #### Example
>
> ```python
> # TODO: example
> ```
The [`DependencyMatcher`](/api/dependencymatcher) lets you match patterns within
the dependency parse using
[Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html)
operators. It follows the same API as the token-based [`Matcher`](/api/matcher).
A pattern added to the dependency matcher consists of a **list of
dictionaries**, with each dictionary describing a **token to match** and its
**relation to an existing token** in the pattern.
<Infobox title="Details & Documentation" emoji="📖" list>
- **Usage:**
[Dependency matching](/usage/rule-based-matching#dependencymatcher),
- **API:** [`DependencyMatcher`](/api/dependencymatcher),
- **Implementation:**
[`spacy/matcher/dependencymatcher.pyx`](https://github.com/explosion/spaCy/tree/develop/spacy/matcher/dependencymatcher.pyx)
</Infobox>
### Type hints and type-based data validation {#features-types}
> #### Example
@ -313,7 +342,8 @@ format for documenting argument and return types.
[`Transformer`](/api/transformer), [`Lemmatizer`](/api/lemmatizer),
[`Morphologizer`](/api/morphologizer),
[`AttributeRuler`](/api/attributeruler),
[`SentenceRecognizer`](/api/sentencerecognizer), [`Pipe`](/api/pipe),
[`SentenceRecognizer`](/api/sentencerecognizer),
[`DependencyMatcher`])(/api/dependencymatcher), [`Pipe`](/api/pipe),
[`Corpus`](/api/corpus)
</Infobox>