diff --git a/spacy/errors.py b/spacy/errors.py
index 8f0666753..bad3e83e4 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -288,12 +288,12 @@ class Errors:
"Span objects, or dicts if set to manual=True.")
E097 = ("Invalid pattern: expected token pattern (list of dicts) or "
"phrase pattern (string) but got:\n{pattern}")
- E098 = ("Invalid pattern specified: expected both SPEC and PATTERN.")
- E099 = ("First node of pattern should be a root node. The root should "
- "only contain NODE_NAME.")
- E100 = ("Nodes apart from the root should contain NODE_NAME, NBOR_NAME and "
- "NBOR_RELOP.")
- E101 = ("NODE_NAME should be a new node and NBOR_NAME should already have "
+ E098 = ("Invalid pattern: expected both RIGHT_ID and RIGHT_ATTRS.")
+    E099 = ("Invalid pattern: the first node of the pattern should be an "
+            "anchor node. The node should only contain RIGHT_ID and RIGHT_ATTRS.")
+    E100 = ("Nodes other than the anchor node should all contain LEFT_ID, "
+            "REL_OP, RIGHT_ID and RIGHT_ATTRS.")
+    E101 = ("RIGHT_ID should be a new node and LEFT_ID should already "
"have been declared in previous edges.")
E102 = ("Can't merge non-disjoint spans. '{token}' is already part of "
"tokens to merge. If you want to find the longest non-overlapping "
@@ -661,6 +661,9 @@ class Errors:
"'{chunk}'. Tokenizer exceptions are only allowed to specify "
"`ORTH` and `NORM`.")
E1006 = ("Unable to initialize {name} model with 0 labels.")
+ E1007 = ("Unsupported DependencyMatcher operator '{op}'.")
+ E1008 = ("Invalid pattern: each pattern should be a list of dicts. Check "
+ "that you are providing a list of patterns as `List[List[dict]]`.")
@add_codes
diff --git a/spacy/matcher/dependencymatcher.pyx b/spacy/matcher/dependencymatcher.pyx
index e0a54e6f1..067b2167c 100644
--- a/spacy/matcher/dependencymatcher.pyx
+++ b/spacy/matcher/dependencymatcher.pyx
@@ -1,16 +1,16 @@
# cython: infer_types=True, profile=True
-from cymem.cymem cimport Pool
-from preshed.maps cimport PreshMap
-from libcpp cimport bool
+from typing import List
import numpy
+from cymem.cymem cimport Pool
+
from .matcher cimport Matcher
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc
-from .matcher import unpickle_matcher
from ..errors import Errors
+from ..tokens import Span
DELIMITER = "||"
@@ -22,36 +22,52 @@ cdef class DependencyMatcher:
"""Match dependency parse tree based on pattern rules."""
cdef Pool mem
cdef readonly Vocab vocab
- cdef readonly Matcher token_matcher
+ cdef readonly Matcher matcher
cdef public object _patterns
+ cdef public object _raw_patterns
cdef public object _keys_to_token
cdef public object _root
- cdef public object _entities
cdef public object _callbacks
cdef public object _nodes
cdef public object _tree
+ cdef public object _ops
- def __init__(self, vocab):
+ def __init__(self, vocab, *, validate=False):
"""Create the DependencyMatcher.
vocab (Vocab): The vocabulary object, which must be shared with the
documents the matcher will operate on.
+ validate (bool): Whether patterns should be validated, passed to
+            Matcher as `validate`.
"""
size = 20
- # TODO: make matcher work with validation
- self.token_matcher = Matcher(vocab, validate=False)
+ self.matcher = Matcher(vocab, validate=validate)
self._keys_to_token = {}
self._patterns = {}
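+        # keep the patterns exactly as provided, for get(), remove() and pickling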
+ self._raw_patterns = {}
self._root = {}
self._nodes = {}
self._tree = {}
- self._entities = {}
self._callbacks = {}
self.vocab = vocab
self.mem = Pool()
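+        # map each relation operator symbol to the method implementing it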
+ self._ops = {
+ "<": self.dep,
+ ">": self.gov,
+ "<<": self.dep_chain,
+ ">>": self.gov_chain,
+ ".": self.imm_precede,
+ ".*": self.precede,
+ ";": self.imm_follow,
+ ";*": self.follow,
+ "$+": self.imm_right_sib,
+ "$-": self.imm_left_sib,
+ "$++": self.right_sib,
+ "$--": self.left_sib,
+ }
def __reduce__(self):
- data = (self.vocab, self._patterns,self._tree, self._callbacks)
+ data = (self.vocab, self._raw_patterns, self._callbacks)
return (unpickle_matcher, data, None, None)
def __len__(self):
@@ -74,54 +90,61 @@ cdef class DependencyMatcher:
idx = 0
visited_nodes = {}
for relation in pattern:
- if "PATTERN" not in relation or "SPEC" not in relation:
+ if not isinstance(relation, dict):
+ raise ValueError(Errors.E1008)
+ if "RIGHT_ATTRS" not in relation and "RIGHT_ID" not in relation:
raise ValueError(Errors.E098.format(key=key))
if idx == 0:
if not(
- "NODE_NAME" in relation["SPEC"]
- and "NBOR_RELOP" not in relation["SPEC"]
- and "NBOR_NAME" not in relation["SPEC"]
+ "RIGHT_ID" in relation
+ and "REL_OP" not in relation
+ and "LEFT_ID" not in relation
):
raise ValueError(Errors.E099.format(key=key))
- visited_nodes[relation["SPEC"]["NODE_NAME"]] = True
+ visited_nodes[relation["RIGHT_ID"]] = True
else:
if not(
- "NODE_NAME" in relation["SPEC"]
- and "NBOR_RELOP" in relation["SPEC"]
- and "NBOR_NAME" in relation["SPEC"]
+ "RIGHT_ID" in relation
+ and "RIGHT_ATTRS" in relation
+ and "REL_OP" in relation
+ and "LEFT_ID" in relation
):
raise ValueError(Errors.E100.format(key=key))
if (
- relation["SPEC"]["NODE_NAME"] in visited_nodes
- or relation["SPEC"]["NBOR_NAME"] not in visited_nodes
+ relation["RIGHT_ID"] in visited_nodes
+ or relation["LEFT_ID"] not in visited_nodes
):
raise ValueError(Errors.E101.format(key=key))
- visited_nodes[relation["SPEC"]["NODE_NAME"]] = True
- visited_nodes[relation["SPEC"]["NBOR_NAME"]] = True
+ if relation["REL_OP"] not in self._ops:
+ raise ValueError(Errors.E1007.format(op=relation["REL_OP"]))
+ visited_nodes[relation["RIGHT_ID"]] = True
+ visited_nodes[relation["LEFT_ID"]] = True
idx = idx + 1
- def add(self, key, patterns, *_patterns, on_match=None):
+ def add(self, key, patterns, *, on_match=None):
"""Add a new matcher rule to the matcher.
key (str): The match ID.
patterns (list): The patterns to add for the given key.
on_match (callable): Optional callback executed on match.
"""
- if patterns is None or hasattr(patterns, "__call__"): # old API
- on_match = patterns
- patterns = _patterns
+ if on_match is not None and not hasattr(on_match, "__call__"):
+ raise ValueError(Errors.E171.format(arg_type=type(on_match)))
+ if patterns is None or not isinstance(patterns, List): # old API
+ raise ValueError(Errors.E948.format(arg_type=type(patterns)))
for pattern in patterns:
if len(pattern) == 0:
raise ValueError(Errors.E012.format(key=key))
- self.validate_input(pattern,key)
+ self.validate_input(pattern, key)
key = self._normalize_key(key)
+ self._raw_patterns.setdefault(key, [])
+ self._raw_patterns[key].extend(patterns)
_patterns = []
for pattern in patterns:
token_patterns = []
for i in range(len(pattern)):
- token_pattern = [pattern[i]["PATTERN"]]
+ token_pattern = [pattern[i]["RIGHT_ATTRS"]]
token_patterns.append(token_pattern)
- # self.patterns.append(token_patterns)
_patterns.append(token_patterns)
self._patterns.setdefault(key, [])
self._callbacks[key] = on_match
@@ -135,7 +158,7 @@ cdef class DependencyMatcher:
# TODO: Better ways to hash edges in pattern?
for j in range(len(_patterns[i])):
k = self._normalize_key(unicode(key) + DELIMITER + unicode(i) + DELIMITER + unicode(j))
- self.token_matcher.add(k, [_patterns[i][j]])
+ self.matcher.add(k, [_patterns[i][j]])
_keys_to_token[k] = j
_keys_to_token_list.append(_keys_to_token)
self._keys_to_token.setdefault(key, [])
@@ -144,14 +167,14 @@ cdef class DependencyMatcher:
for pattern in patterns:
nodes = {}
for i in range(len(pattern)):
- nodes[pattern[i]["SPEC"]["NODE_NAME"]] = i
+ nodes[pattern[i]["RIGHT_ID"]] = i
_nodes_list.append(nodes)
self._nodes.setdefault(key, [])
self._nodes[key].extend(_nodes_list)
# Create an object tree to traverse later on. This data structure
# enables easy tree pattern match. Doc-Token based tree cannot be
# reused since it is memory-heavy and tightly coupled with the Doc.
- self.retrieve_tree(patterns, _nodes_list,key)
+ self.retrieve_tree(patterns, _nodes_list, key)
def retrieve_tree(self, patterns, _nodes_list, key):
_heads_list = []
@@ -161,13 +184,13 @@ cdef class DependencyMatcher:
root = -1
for j in range(len(patterns[i])):
token_pattern = patterns[i][j]
- if ("NBOR_RELOP" not in token_pattern["SPEC"]):
+ if ("REL_OP" not in token_pattern):
heads[j] = ('root', j)
root = j
else:
heads[j] = (
- token_pattern["SPEC"]["NBOR_RELOP"],
- _nodes_list[i][token_pattern["SPEC"]["NBOR_NAME"]]
+ token_pattern["REL_OP"],
+ _nodes_list[i][token_pattern["LEFT_ID"]]
)
_heads_list.append(heads)
_root_list.append(root)
@@ -202,11 +225,21 @@ cdef class DependencyMatcher:
RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
"""
key = self._normalize_key(key)
- if key not in self._patterns:
+ if key not in self._raw_patterns:
return default
- return (self._callbacks[key], self._patterns[key])
+ return (self._callbacks[key], self._raw_patterns[key])
- def __call__(self, Doc doc):
+ def remove(self, key):
+ key = self._normalize_key(key)
+        if key not in self._patterns:
+ raise ValueError(Errors.E175.format(key=key))
+ self._patterns.pop(key)
+ self._raw_patterns.pop(key)
+ self._nodes.pop(key)
+ self._tree.pop(key)
+ self._root.pop(key)
+
+ def __call__(self, object doclike):
"""Find all token sequences matching the supplied pattern.
doclike (Doc or Span): The document to match over.
@@ -214,8 +247,14 @@ cdef class DependencyMatcher:
describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `key` are both integers.
"""
+ if isinstance(doclike, Doc):
+ doc = doclike
+ elif isinstance(doclike, Span):
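+            # Span.as_doc() copies the span into a new Doc, so match token
+            # indices are relative to the start of the span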
+ doc = doclike.as_doc()
+ else:
+ raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
matched_key_trees = []
- matches = self.token_matcher(doc)
+ matches = self.matcher(doc)
for key in list(self._patterns.keys()):
_patterns_list = self._patterns[key]
_keys_to_token_list = self._keys_to_token[key]
@@ -244,26 +283,26 @@ cdef class DependencyMatcher:
length = len(_nodes)
matched_trees = []
- self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees)
- matched_key_trees.append((key,matched_trees))
-
- for i, (ent_id, nodes) in enumerate(matched_key_trees):
- on_match = self._callbacks.get(ent_id)
+ self.recurse(_tree, id_to_position, _node_operator_map, 0, [], matched_trees)
+ for matched_tree in matched_trees:
+ matched_key_trees.append((key, matched_tree))
+ for i, (match_id, nodes) in enumerate(matched_key_trees):
+ on_match = self._callbacks.get(match_id)
if on_match is not None:
on_match(self, doc, i, matched_key_trees)
return matched_key_trees
- def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visited_nodes,matched_trees):
- cdef bool isValid;
- if(patternLength == len(id_to_position.keys())):
+ def recurse(self, tree, id_to_position, _node_operator_map, int patternLength, visited_nodes, matched_trees):
+        cdef bint isValid
+ if patternLength == len(id_to_position.keys()):
isValid = True
for node in range(patternLength):
- if(node in tree):
+ if node in tree:
for idx, (relop,nbor) in enumerate(tree[node]):
computed_nbors = numpy.asarray(_node_operator_map[visited_nodes[node]][relop])
isNbor = False
for computed_nbor in computed_nbors:
- if(computed_nbor.i == visited_nodes[nbor]):
+ if computed_nbor.i == visited_nodes[nbor]:
isNbor = True
isValid = isValid & isNbor
if(isValid):
@@ -271,14 +310,14 @@ cdef class DependencyMatcher:
return
allPatternNodes = numpy.asarray(id_to_position[patternLength])
for patternNode in allPatternNodes:
- self.recurse(tree,id_to_position,_node_operator_map,patternLength+1,visited_nodes+[patternNode],matched_trees)
+ self.recurse(tree, id_to_position, _node_operator_map, patternLength+1, visited_nodes+[patternNode], matched_trees)
# Given a node and an edge operator, to return the list of nodes
# from the doc that belong to node+operator. This is used to store
# all the results beforehand to prevent unnecessary computation while
# pattern matching
# _node_operator_map[node][operator] = [...]
- def get_node_operator_map(self,doc,tree,id_to_position,nodes,root):
+ def get_node_operator_map(self, doc, tree, id_to_position, nodes, root):
_node_operator_map = {}
all_node_indices = nodes.values()
all_operators = []
@@ -295,24 +334,14 @@ cdef class DependencyMatcher:
_node_operator_map[node] = {}
for operator in all_operators:
_node_operator_map[node][operator] = []
- # Used to invoke methods for each operator
- switcher = {
- "<": self.dep,
- ">": self.gov,
- "<<": self.dep_chain,
- ">>": self.gov_chain,
- ".": self.imm_precede,
- "$+": self.imm_right_sib,
- "$-": self.imm_left_sib,
- "$++": self.right_sib,
- "$--": self.left_sib
- }
for operator in all_operators:
for node in all_nodes:
- _node_operator_map[node][operator] = switcher.get(operator)(doc,node)
+ _node_operator_map[node][operator] = self._ops.get(operator)(doc, node)
return _node_operator_map
def dep(self, doc, node):
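+        # a sentence root is its own head, so it has no head token to return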
+ if doc[node].head == doc[node]:
+ return []
return [doc[node].head]
def gov(self,doc,node):
@@ -322,36 +351,51 @@ cdef class DependencyMatcher:
return list(doc[node].ancestors)
def gov_chain(self, doc, node):
- return list(doc[node].subtree)
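+        # Token.subtree includes the token itself, which is excluded here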
+ return [t for t in doc[node].subtree if t != doc[node]]
def imm_precede(self, doc, node):
- if node > 0:
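+        # limit matches to tokens within the node's sentence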
+ sent = self._get_sent(doc[node])
+ if node < len(doc) - 1 and doc[node + 1] in sent:
+ return [doc[node + 1]]
+ return []
+
+ def precede(self, doc, node):
+ sent = self._get_sent(doc[node])
+ return [doc[i] for i in range(node + 1, sent.end)]
+
+ def imm_follow(self, doc, node):
+ sent = self._get_sent(doc[node])
+ if node > 0 and doc[node - 1] in sent:
return [doc[node - 1]]
return []
+ def follow(self, doc, node):
+ sent = self._get_sent(doc[node])
+ return [doc[i] for i in range(sent.start, node)]
+
def imm_right_sib(self, doc, node):
for child in list(doc[node].head.children):
- if child.i == node - 1:
+ if child.i == node + 1:
return [doc[child.i]]
return []
def imm_left_sib(self, doc, node):
for child in list(doc[node].head.children):
- if child.i == node + 1:
+ if child.i == node - 1:
return [doc[child.i]]
return []
def right_sib(self, doc, node):
candidate_children = []
for child in list(doc[node].head.children):
- if child.i < node:
+ if child.i > node:
candidate_children.append(doc[child.i])
return candidate_children
def left_sib(self, doc, node):
candidate_children = []
for child in list(doc[node].head.children):
- if child.i > node:
+ if child.i < node:
candidate_children.append(doc[child.i])
return candidate_children
@@ -360,3 +404,15 @@ cdef class DependencyMatcher:
return self.vocab.strings.add(key)
else:
return key
+
+ def _get_sent(self, token):
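+        # the last ancestor is the sentence root; its left/right edges span the
+        # whole tree, so the sentence can be computed from the parse alone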
+ root = (list(token.ancestors) or [token])[-1]
+ return token.doc[root.left_edge.i:root.right_edge.i + 1]
+
+
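+# used by DependencyMatcher.__reduce__ to restore the matcher when unpickling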
+def unpickle_matcher(vocab, patterns, callbacks):
+ matcher = DependencyMatcher(vocab)
+ for key, pattern in patterns.items():
+ callback = callbacks.get(key, None)
+ matcher.add(key, pattern, on_match=callback)
+ return matcher
diff --git a/spacy/tests/matcher/test_dependency_matcher.py b/spacy/tests/matcher/test_dependency_matcher.py
new file mode 100644
index 000000000..72005cc82
--- /dev/null
+++ b/spacy/tests/matcher/test_dependency_matcher.py
@@ -0,0 +1,334 @@
+import pytest
+import pickle
+import re
+import copy
+from mock import Mock
+from spacy.matcher import DependencyMatcher
+from ..util import get_doc
+
+
+@pytest.fixture
+def doc(en_vocab):
+ text = "The quick brown fox jumped over the lazy fox"
+ heads = [3, 2, 1, 1, 0, -1, 2, 1, -3]
+ deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "pobj", "det", "amod"]
+ doc = get_doc(en_vocab, text.split(), heads=heads, deps=deps)
+ return doc
+
+
+@pytest.fixture
+def patterns(en_vocab):
+ def is_brown_yellow(text):
+ return bool(re.compile(r"brown|yellow").match(text))
+
+ IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
+
+ pattern1 = [
+ {"RIGHT_ID": "fox", "RIGHT_ATTRS": {"ORTH": "fox"}},
+ {
+ "LEFT_ID": "fox",
+ "REL_OP": ">",
+ "RIGHT_ID": "q",
+ "RIGHT_ATTRS": {"ORTH": "quick", "DEP": "amod"},
+ },
+ {
+ "LEFT_ID": "fox",
+ "REL_OP": ">",
+ "RIGHT_ID": "r",
+ "RIGHT_ATTRS": {IS_BROWN_YELLOW: True},
+ },
+ ]
+
+ pattern2 = [
+ {"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
+ {
+ "LEFT_ID": "jumped",
+ "REL_OP": ">",
+ "RIGHT_ID": "fox1",
+ "RIGHT_ATTRS": {"ORTH": "fox"},
+ },
+ {
+ "LEFT_ID": "jumped",
+ "REL_OP": ".",
+ "RIGHT_ID": "over",
+ "RIGHT_ATTRS": {"ORTH": "over"},
+ },
+ ]
+
+ pattern3 = [
+ {"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
+ {
+ "LEFT_ID": "jumped",
+ "REL_OP": ">",
+ "RIGHT_ID": "fox",
+ "RIGHT_ATTRS": {"ORTH": "fox"},
+ },
+ {
+ "LEFT_ID": "fox",
+ "REL_OP": ">>",
+ "RIGHT_ID": "r",
+ "RIGHT_ATTRS": {"ORTH": "brown"},
+ },
+ ]
+
+ pattern4 = [
+ {"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
+ {
+ "LEFT_ID": "jumped",
+ "REL_OP": ">",
+ "RIGHT_ID": "fox",
+ "RIGHT_ATTRS": {"ORTH": "fox"},
+ }
+ ]
+
+ pattern5 = [
+ {"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
+ {
+ "LEFT_ID": "jumped",
+ "REL_OP": ">>",
+ "RIGHT_ID": "fox",
+ "RIGHT_ATTRS": {"ORTH": "fox"},
+ },
+ ]
+
+ return [pattern1, pattern2, pattern3, pattern4, pattern5]
+
+
+@pytest.fixture
+def dependency_matcher(en_vocab, patterns, doc):
+ matcher = DependencyMatcher(en_vocab)
+ mock = Mock()
+ for i in range(1, len(patterns) + 1):
+ if i == 1:
+ matcher.add("pattern1", [patterns[0]], on_match=mock)
+ else:
+ matcher.add("pattern" + str(i), [patterns[i - 1]])
+
+ return matcher
+
+
+def test_dependency_matcher(dependency_matcher, doc, patterns):
+ assert len(dependency_matcher) == 5
+ assert "pattern3" in dependency_matcher
+ assert dependency_matcher.get("pattern3") == (None, [patterns[2]])
+ matches = dependency_matcher(doc)
+ assert len(matches) == 6
+ assert matches[0][1] == [3, 1, 2]
+ assert matches[1][1] == [4, 3, 5]
+ assert matches[2][1] == [4, 3, 2]
+ assert matches[3][1] == [4, 3]
+ assert matches[4][1] == [4, 3]
+ assert matches[5][1] == [4, 8]
+
+ span = doc[0:6]
+ matches = dependency_matcher(span)
+ assert len(matches) == 5
+ assert matches[0][1] == [3, 1, 2]
+ assert matches[1][1] == [4, 3, 5]
+ assert matches[2][1] == [4, 3, 2]
+ assert matches[3][1] == [4, 3]
+ assert matches[4][1] == [4, 3]
+
+
+def test_dependency_matcher_pickle(en_vocab, patterns, doc):
+ matcher = DependencyMatcher(en_vocab)
+ for i in range(1, len(patterns) + 1):
+ matcher.add("pattern" + str(i), [patterns[i - 1]])
+
+ matches = matcher(doc)
+ assert matches[0][1] == [3, 1, 2]
+ assert matches[1][1] == [4, 3, 5]
+ assert matches[2][1] == [4, 3, 2]
+ assert matches[3][1] == [4, 3]
+ assert matches[4][1] == [4, 3]
+ assert matches[5][1] == [4, 8]
+
+ b = pickle.dumps(matcher)
+ matcher_r = pickle.loads(b)
+
+ assert len(matcher) == len(matcher_r)
+ matches = matcher_r(doc)
+ assert matches[0][1] == [3, 1, 2]
+ assert matches[1][1] == [4, 3, 5]
+ assert matches[2][1] == [4, 3, 2]
+ assert matches[3][1] == [4, 3]
+ assert matches[4][1] == [4, 3]
+ assert matches[5][1] == [4, 8]
+
+
+def test_dependency_matcher_pattern_validation(en_vocab):
+ pattern = [
+ {"RIGHT_ID": "fox", "RIGHT_ATTRS": {"ORTH": "fox"}},
+ {
+ "LEFT_ID": "fox",
+ "REL_OP": ">",
+ "RIGHT_ID": "q",
+ "RIGHT_ATTRS": {"ORTH": "quick", "DEP": "amod"},
+ },
+ {
+ "LEFT_ID": "fox",
+ "REL_OP": ">",
+ "RIGHT_ID": "r",
+ "RIGHT_ATTRS": {"ORTH": "brown"},
+ },
+ ]
+
+ matcher = DependencyMatcher(en_vocab)
+ # original pattern is valid
+ matcher.add("FOUNDED", [pattern])
+ # individual pattern not wrapped in a list
+ with pytest.raises(ValueError):
+ matcher.add("FOUNDED", pattern)
+ # no anchor node
+ with pytest.raises(ValueError):
+ matcher.add("FOUNDED", [pattern[1:]])
+ # required keys missing
+ with pytest.raises(ValueError):
+ pattern2 = copy.deepcopy(pattern)
+ del pattern2[0]["RIGHT_ID"]
+ matcher.add("FOUNDED", [pattern2])
+ with pytest.raises(ValueError):
+ pattern2 = copy.deepcopy(pattern)
+ del pattern2[1]["RIGHT_ID"]
+ matcher.add("FOUNDED", [pattern2])
+ with pytest.raises(ValueError):
+ pattern2 = copy.deepcopy(pattern)
+ del pattern2[1]["RIGHT_ATTRS"]
+ matcher.add("FOUNDED", [pattern2])
+ with pytest.raises(ValueError):
+ pattern2 = copy.deepcopy(pattern)
+ del pattern2[1]["LEFT_ID"]
+ matcher.add("FOUNDED", [pattern2])
+ with pytest.raises(ValueError):
+ pattern2 = copy.deepcopy(pattern)
+ del pattern2[1]["REL_OP"]
+ matcher.add("FOUNDED", [pattern2])
+ # invalid operator
+ with pytest.raises(ValueError):
+ pattern2 = copy.deepcopy(pattern)
+ pattern2[1]["REL_OP"] = "!!!"
+ matcher.add("FOUNDED", [pattern2])
+ # duplicate node name
+ with pytest.raises(ValueError):
+ pattern2 = copy.deepcopy(pattern)
+ pattern2[1]["RIGHT_ID"] = "fox"
+ matcher.add("FOUNDED", [pattern2])
+
+
+def test_dependency_matcher_callback(en_vocab, doc):
+ pattern = [
+ {"RIGHT_ID": "quick", "RIGHT_ATTRS": {"ORTH": "quick"}},
+ ]
+
+ matcher = DependencyMatcher(en_vocab)
+ mock = Mock()
+ matcher.add("pattern", [pattern], on_match=mock)
+ matches = matcher(doc)
+ mock.assert_called_once_with(matcher, doc, 0, matches)
+
+ # check that matches with and without callback are the same (#4590)
+ matcher2 = DependencyMatcher(en_vocab)
+ matcher2.add("pattern", [pattern])
+ matches2 = matcher2(doc)
+ assert matches == matches2
+
+
+@pytest.mark.parametrize(
+ "op,num_matches", [(".", 8), (".*", 20), (";", 8), (";*", 20),]
+)
+def test_dependency_matcher_precedence_ops(en_vocab, op, num_matches):
+ # two sentences to test that all matches are within the same sentence
+ doc = get_doc(
+ en_vocab,
+ words=["a", "b", "c", "d", "e"] * 2,
+ heads=[0, -1, -2, -3, -4] * 2,
+ deps=["dep"] * 10,
+ )
+ match_count = 0
+ for text in ["a", "b", "c", "d", "e"]:
+ pattern = [
+ {"RIGHT_ID": "1", "RIGHT_ATTRS": {"ORTH": text}},
+ {"LEFT_ID": "1", "REL_OP": op, "RIGHT_ID": "2", "RIGHT_ATTRS": {},},
+ ]
+ matcher = DependencyMatcher(en_vocab)
+ matcher.add("A", [pattern])
+ matches = matcher(doc)
+ match_count += len(matches)
+ for match in matches:
+ match_id, token_ids = match
+ # token_ids[0] op token_ids[1]
+ if op == ".":
+ assert token_ids[0] == token_ids[1] - 1
+ elif op == ";":
+ assert token_ids[0] == token_ids[1] + 1
+ elif op == ".*":
+ assert token_ids[0] < token_ids[1]
+ elif op == ";*":
+ assert token_ids[0] > token_ids[1]
+ # all tokens are within the same sentence
+ assert doc[token_ids[0]].sent == doc[token_ids[1]].sent
+ assert match_count == num_matches
+
+
+@pytest.mark.parametrize(
+ "left,right,op,num_matches",
+ [
+ ("fox", "jumped", "<", 1),
+ ("the", "lazy", "<", 0),
+ ("jumped", "jumped", "<", 0),
+ ("fox", "jumped", ">", 0),
+ ("fox", "lazy", ">", 1),
+ ("lazy", "lazy", ">", 0),
+ ("fox", "jumped", "<<", 2),
+ ("jumped", "fox", "<<", 0),
+ ("the", "fox", "<<", 2),
+ ("fox", "jumped", ">>", 0),
+ ("over", "the", ">>", 1),
+ ("fox", "the", ">>", 2),
+ ("fox", "jumped", ".", 1),
+ ("lazy", "fox", ".", 1),
+ ("the", "fox", ".", 0),
+ ("the", "the", ".", 0),
+ ("fox", "jumped", ";", 0),
+ ("lazy", "fox", ";", 0),
+ ("the", "fox", ";", 0),
+ ("the", "the", ";", 0),
+ ("quick", "fox", ".*", 2),
+ ("the", "fox", ".*", 3),
+ ("the", "the", ".*", 1),
+ ("fox", "jumped", ";*", 1),
+ ("quick", "fox", ";*", 0),
+ ("the", "fox", ";*", 1),
+ ("the", "the", ";*", 1),
+ ("quick", "brown", "$+", 1),
+ ("brown", "quick", "$+", 0),
+ ("brown", "brown", "$+", 0),
+ ("quick", "brown", "$-", 0),
+ ("brown", "quick", "$-", 1),
+ ("brown", "brown", "$-", 0),
+ ("the", "brown", "$++", 1),
+ ("brown", "the", "$++", 0),
+ ("brown", "brown", "$++", 0),
+ ("the", "brown", "$--", 0),
+ ("brown", "the", "$--", 1),
+ ("brown", "brown", "$--", 0),
+ ],
+)
+def test_dependency_matcher_ops(en_vocab, doc, left, right, op, num_matches):
+ right_id = right
+ if left == right:
+ right_id = right + "2"
+ pattern = [
+ {"RIGHT_ID": left, "RIGHT_ATTRS": {"LOWER": left}},
+ {
+ "LEFT_ID": left,
+ "REL_OP": op,
+ "RIGHT_ID": right_id,
+ "RIGHT_ATTRS": {"LOWER": right},
+ },
+ ]
+
+ matcher = DependencyMatcher(en_vocab)
+ matcher.add("pattern", [pattern])
+ matches = matcher(doc)
+ assert len(matches) == num_matches
diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py
index 8310c4466..e0f335a19 100644
--- a/spacy/tests/matcher/test_matcher_api.py
+++ b/spacy/tests/matcher/test_matcher_api.py
@@ -1,7 +1,6 @@
import pytest
-import re
from mock import Mock
-from spacy.matcher import Matcher, DependencyMatcher
+from spacy.matcher import Matcher
from spacy.tokens import Doc, Token, Span
from ..doc.test_underscore import clean_underscore # noqa: F401
@@ -292,84 +291,6 @@ def test_matcher_extension_set_membership(en_vocab):
assert len(matches) == 0
-@pytest.fixture
-def text():
- return "The quick brown fox jumped over the lazy fox"
-
-
-@pytest.fixture
-def heads():
- return [3, 2, 1, 1, 0, -1, 2, 1, -3]
-
-
-@pytest.fixture
-def deps():
- return ["det", "amod", "amod", "nsubj", "prep", "pobj", "det", "amod"]
-
-
-@pytest.fixture
-def dependency_matcher(en_vocab):
- def is_brown_yellow(text):
- return bool(re.compile(r"brown|yellow|over").match(text))
-
- IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
-
- pattern1 = [
- {"SPEC": {"NODE_NAME": "fox"}, "PATTERN": {"ORTH": "fox"}},
- {
- "SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
- "PATTERN": {"ORTH": "quick", "DEP": "amod"},
- },
- {
- "SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
- "PATTERN": {IS_BROWN_YELLOW: True},
- },
- ]
-
- pattern2 = [
- {"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
- {
- "SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
- "PATTERN": {"ORTH": "fox"},
- },
- {
- "SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"},
- "PATTERN": {"ORTH": "fox"},
- },
- ]
-
- pattern3 = [
- {"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
- {
- "SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
- "PATTERN": {"ORTH": "fox"},
- },
- {
- "SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"},
- "PATTERN": {"ORTH": "brown"},
- },
- ]
-
- matcher = DependencyMatcher(en_vocab)
- matcher.add("pattern1", [pattern1])
- matcher.add("pattern2", [pattern2])
- matcher.add("pattern3", [pattern3])
-
- return matcher
-
-
-def test_dependency_matcher_compile(dependency_matcher):
- assert len(dependency_matcher) == 3
-
-
-# def test_dependency_matcher(dependency_matcher, text, heads, deps):
-# doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps)
-# matches = dependency_matcher(doc)
-# assert matches[0][1] == [[3, 1, 2]]
-# assert matches[1][1] == [[4, 3, 3]]
-# assert matches[2][1] == [[4, 3, 2]]
-
-
def test_matcher_basic_check(en_vocab):
matcher = Matcher(en_vocab)
# Potential mistake: pass in pattern instead of list of patterns
diff --git a/spacy/tests/regression/test_issue4501-5000.py b/spacy/tests/regression/test_issue4501-5000.py
index 39533f70a..d83a2c718 100644
--- a/spacy/tests/regression/test_issue4501-5000.py
+++ b/spacy/tests/regression/test_issue4501-5000.py
@@ -38,32 +38,6 @@ def test_gold_misaligned(en_tokenizer, text, words):
Example.from_dict(doc, {"words": words})
-def test_issue4590(en_vocab):
- """Test that matches param in on_match method are the same as matches run with no on_match method"""
- pattern = [
- {"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
- {
- "SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
- "PATTERN": {"ORTH": "fox"},
- },
- {
- "SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"},
- "PATTERN": {"ORTH": "fox"},
- },
- ]
-
- on_match = Mock()
- matcher = DependencyMatcher(en_vocab)
- matcher.add("pattern", on_match, pattern)
- text = "The quick brown fox jumped over the lazy fox"
- heads = [3, 2, 1, 1, 0, -1, 2, 1, -3]
- deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "det", "amod", "pobj"]
- doc = get_doc(en_vocab, text.split(), heads=heads, deps=deps)
- matches = matcher(doc)
- on_match_args = on_match.call_args
- assert on_match_args[0][3] == matches
-
-
def test_issue4651_with_phrase_matcher_attr():
"""Test that the EntityRuler PhraseMatcher is deserialized correctly using
the method from_disk when the EntityRuler argument phrase_matcher_attr is
diff --git a/website/docs/api/dependencymatcher.md b/website/docs/api/dependencymatcher.md
index b0395cc42..c90a715d9 100644
--- a/website/docs/api/dependencymatcher.md
+++ b/website/docs/api/dependencymatcher.md
@@ -1,65 +1,91 @@
---
title: DependencyMatcher
-teaser: Match sequences of tokens, based on the dependency parse
+teaser: Match subtrees within a dependency parse
tag: class
+new: 3
source: spacy/matcher/dependencymatcher.pyx
---
The `DependencyMatcher` follows the same API as the [`Matcher`](/api/matcher)
and [`PhraseMatcher`](/api/phrasematcher) and lets you match on dependency trees
-using the
-[Semgrex syntax](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html).
-It requires a trained [`DependencyParser`](/api/parser) or other component that
-sets the `Token.dep` attribute.
+using
+[Semgrex operators](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html).
+It requires a pretrained [`DependencyParser`](/api/parser) or other component
+that sets the `Token.dep` and `Token.head` attributes. See the
+[usage guide](/usage/rule-based-matching#dependencymatcher) for examples.
## Pattern format {#patterns}
-> ```json
+> ```python
> ### Example
+> # pattern: "[subject] ... initially founded"
> [
+> # anchor token: founded
> {
-> "SPEC": {"NODE_NAME": "founded"},
-> "PATTERN": {"ORTH": "founded"}
+> "RIGHT_ID": "founded",
+> "RIGHT_ATTRS": {"ORTH": "founded"}
> },
+> # founded -> subject
> {
-> "SPEC": {
-> "NODE_NAME": "founder",
-> "NBOR_RELOP": ">",
-> "NBOR_NAME": "founded"
-> },
-> "PATTERN": {"DEP": "nsubj"}
+> "LEFT_ID": "founded",
+> "REL_OP": ">",
+> "RIGHT_ID": "subject",
+> "RIGHT_ATTRS": {"DEP": "nsubj"}
> },
+> # "founded" follows "initially"
> {
-> "SPEC": {
-> "NODE_NAME": "object",
-> "NBOR_RELOP": ">",
-> "NBOR_NAME": "founded"
-> },
-> "PATTERN": {"DEP": "dobj"}
+> "LEFT_ID": "founded",
+> "REL_OP": ";",
+> "RIGHT_ID": "initially",
+> "RIGHT_ATTRS": {"ORTH": "initially"}
> }
> ]
> ```
A pattern added to the `DependencyMatcher` consists of a list of dictionaries,
-with each dictionary describing a node to match. Each pattern should have the
-following top-level keys:
+with each dictionary describing a token to match. Except for the first
+dictionary, which defines an anchor token using only `RIGHT_ID` and
+`RIGHT_ATTRS`, each dictionary should have the following keys:
-| Name | Description |
-| --------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
-| `PATTERN` | The token attributes to match in the same format as patterns provided to the regular token-based [`Matcher`](/api/matcher). ~~Dict[str, Any]~~ |
-| `SPEC` | The relationships of the nodes in the subtree that should be matched. ~~Dict[str, str]~~ |
+| Name | Description |
+| ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `LEFT_ID` | The name of the left-hand node in the relation, which has been defined in an earlier node. ~~str~~ |
+| `REL_OP` | An operator that describes how the two nodes are related. ~~str~~ |
+| `RIGHT_ID` | A unique name for the right-hand node in the relation. ~~str~~ |
+| `RIGHT_ATTRS` | The token attributes to match for the right-hand node in the same format as patterns provided to the regular token-based [`Matcher`](/api/matcher). ~~Dict[str, Any]~~ |
-The `SPEC` includes the following fields:
+
-| Name | Description |
-| ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| `NODE_NAME` | A unique name for this node to refer to it in other specs. ~~str~~ |
-| `NBOR_RELOP` | A [Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html) operator that describes how the two nodes are related. ~~str~~ |
-| `NBOR_NAME` | The unique name of the node that this node is connected to. ~~str~~ |
+For examples of how to construct dependency matcher patterns for different types
+of relations, see the usage guide on
+[dependency matching](/usage/rule-based-matching#dependencymatcher).
+
+
+
+### Operators
+
+The following operators are supported by the `DependencyMatcher`, most of which
+come directly from
+[Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html):
+
+| Symbol | Description |
+| --------- | -------------------------------------------------------------------------------------------------------------------- |
+| `A < B` | `A` is the immediate dependent of `B`. |
+| `A > B` | `A` is the immediate head of `B`. |
+| `A << B` | `A` is the dependent in a chain to `B` following dep → head paths. |
+| `A >> B` | `A` is the head in a chain to `B` following head → dep paths. |
+| `A . B` | `A` immediately precedes `B`, i.e. `A.i == B.i - 1`, and both are within the same dependency tree. |
+| `A .* B` | `A` precedes `B`, i.e. `A.i < B.i`, and both are within the same dependency tree _(not in Semgrex)_. |
+| `A ; B` | `A` immediately follows `B`, i.e. `A.i == B.i + 1`, and both are within the same dependency tree _(not in Semgrex)_. |
+| `A ;* B` | `A` follows `B`, i.e. `A.i > B.i`, and both are within the same dependency tree _(not in Semgrex)_. |
+| `A $+ B` | `B` is a right immediate sibling of `A`, i.e. `A` and `B` have the same parent and `A.i == B.i - 1`. |
+| `A $- B` | `B` is a left immediate sibling of `A`, i.e. `A` and `B` have the same parent and `A.i == B.i + 1`. |
+| `A $++ B` | `B` is a right sibling of `A`, i.e. `A` and `B` have the same parent and `A.i < B.i`. |
+| `A $-- B` | `B` is a left sibling of `A`, i.e. `A` and `B` have the same parent and `A.i > B.i`. |
## DependencyMatcher.\_\_init\_\_ {#init tag="method"}
-Create a rule-based `DependencyMatcher`.
+Create a `DependencyMatcher`.
> #### Example
>
@@ -68,13 +94,15 @@ Create a rule-based `DependencyMatcher`.
> matcher = DependencyMatcher(nlp.vocab)
> ```
-| Name | Description |
-| ------- | ----------------------------------------------------------------------------------------------------- |
-| `vocab` | The vocabulary object, which must be shared with the documents the matcher will operate on. ~~Vocab~~ |
+| Name | Description |
+| -------------- | ----------------------------------------------------------------------------------------------------- |
+| `vocab` | The vocabulary object, which must be shared with the documents the matcher will operate on. ~~Vocab~~ |
+| _keyword-only_ | |
+| `validate` | Validate all patterns added to this matcher. ~~bool~~ |
## DependencyMatcher.\_\_call\_\_ {#call tag="method"}
-Find all token sequences matching the supplied patterns on the `Doc` or `Span`.
+Find all tokens matching the supplied patterns on the `Doc` or `Span`.
> #### Example
>
@@ -82,36 +110,32 @@ Find all token sequences matching the supplied patterns on the `Doc` or `Span`.
> from spacy.matcher import DependencyMatcher
>
> matcher = DependencyMatcher(nlp.vocab)
-> pattern = [
-> {"SPEC": {"NODE_NAME": "founded"}, "PATTERN": {"ORTH": "founded"}},
-> {"SPEC": {"NODE_NAME": "founder", "NBOR_RELOP": ">", "NBOR_NAME": "founded"}, "PATTERN": {"DEP": "nsubj"}},
-> ]
-> matcher.add("Founder", [pattern])
+> pattern = [{"RIGHT_ID": "founded_id",
+> "RIGHT_ATTRS": {"ORTH": "founded"}}]
+> matcher.add("FOUNDED", [pattern])
> doc = nlp("Bill Gates founded Microsoft.")
> matches = matcher(doc)
> ```
-| Name | Description |
-| ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `doclike` | The `Doc` or `Span` to match over. ~~Union[Doc, Span]~~ |
-| **RETURNS** | A list of `(match_id, start, end)` tuples, describing the matches. A match tuple describes a span `doc[start:end`]. The `match_id` is the ID of the added match pattern. ~~List[Tuple[int, int, int]]~~ |
+| Name | Description |
+| ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `doclike` | The `Doc` or `Span` to match over. ~~Union[Doc, Span]~~ |
+| **RETURNS** | A list of `(match_id, token_ids)` tuples, describing the matches. The `match_id` is the ID of the match pattern and `token_ids` is a list of token indices matched by the pattern, where the position of each token in the list corresponds to the position of the node specification in the pattern. ~~List[Tuple[int, List[int]]]~~ |
## DependencyMatcher.\_\_len\_\_ {#len tag="method"}
-Get the number of rules (edges) added to the dependency matcher. Note that this
-only returns the number of rules (identical with the number of IDs), not the
-number of individual patterns.
+Get the number of rules added to the dependency matcher. Note that this only
+returns the number of rules (identical to the number of IDs), not the number
+of individual patterns.
> #### Example
>
> ```python
> matcher = DependencyMatcher(nlp.vocab)
> assert len(matcher) == 0
-> pattern = [
-> {"SPEC": {"NODE_NAME": "founded"}, "PATTERN": {"ORTH": "founded"}},
-> {"SPEC": {"NODE_NAME": "START_ENTITY", "NBOR_RELOP": ">", "NBOR_NAME": "founded"}, "PATTERN": {"DEP": "nsubj"}},
-> ]
-> matcher.add("Rule", [pattern])
+> pattern = [{"RIGHT_ID": "founded_id",
+> "RIGHT_ATTRS": {"ORTH": "founded"}}]
+> matcher.add("FOUNDED", [pattern])
> assert len(matcher) == 1
> ```
@@ -126,10 +150,10 @@ Check whether the matcher contains rules for a match ID.
> #### Example
>
> ```python
-> matcher = Matcher(nlp.vocab)
-> assert "Rule" not in matcher
-> matcher.add("Rule", [pattern])
-> assert "Rule" in matcher
+> matcher = DependencyMatcher(nlp.vocab)
+> assert "FOUNDED" not in matcher
+> matcher.add("FOUNDED", [pattern])
+> assert "FOUNDED" in matcher
> ```
| Name | Description |
@@ -152,33 +176,15 @@ will be overwritten.
> print('Matched!', matches)
>
> matcher = DependencyMatcher(nlp.vocab)
-> matcher.add("TEST_PATTERNS", patterns)
+> matcher.add("FOUNDED", patterns, on_match=on_match)
> ```
-| Name | Description |
-| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `match_id` | An ID for the thing you're matching. ~~str~~ |
-| `patterns` | list | Match pattern. A pattern consists of a list of dicts, where each dict describes a `"PATTERN"` and `"SPEC"`. ~~List[List[Dict[str, dict]]]~~ |
-| _keyword-only_ | | |
-| `on_match` | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. ~~Optional[Callable[[Matcher, Doc, int, List[tuple], Any]]~~ |
-
-## DependencyMatcher.remove {#remove tag="method"}
-
-Remove a rule from the matcher. A `KeyError` is raised if the match ID does not
-exist.
-
-> #### Example
->
-> ```python
-> matcher.add("Rule", [pattern]])
-> assert "Rule" in matcher
-> matcher.remove("Rule")
-> assert "Rule" not in matcher
-> ```
-
-| Name | Description |
-| ----- | --------------------------------- |
-| `key` | The ID of the match rule. ~~str~~ |
+| Name | Description |
+| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `match_id` | An ID for the patterns. ~~str~~ |
+| `patterns` | A list of match patterns. A pattern consists of a list of dicts, where each dict describes a token in the tree. ~~List[List[Dict[str, Union[str, Dict]]]]~~ |
+| _keyword-only_ | | |
+| `on_match`     | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. ~~Optional[Callable[[DependencyMatcher, Doc, int, List[Tuple]], Any]]~~                |
## DependencyMatcher.get {#get tag="method"}
@@ -188,11 +194,29 @@ Retrieve the pattern stored for a key. Returns the rule as an
> #### Example
>
> ```python
-> matcher.add("Rule", [pattern], on_match=on_match)
-> on_match, patterns = matcher.get("Rule")
+> matcher.add("FOUNDED", patterns, on_match=on_match)
+> on_match, patterns = matcher.get("FOUNDED")
> ```
-| Name | Description |
-| ----------- | --------------------------------------------------------------------------------------------- |
-| `key` | The ID of the match rule. ~~str~~ |
-| **RETURNS** | The rule, as an `(on_match, patterns)` tuple. ~~Tuple[Optional[Callable], List[List[dict]]]~~ |
+| Name | Description |
+| ----------- | ----------------------------------------------------------------------------------------------------------- |
+| `key` | The ID of the match rule. ~~str~~ |
+| **RETURNS** | The rule, as an `(on_match, patterns)` tuple. ~~Tuple[Optional[Callable], List[List[Union[Dict, Tuple]]]]~~ |
+
+## DependencyMatcher.remove {#remove tag="method"}
+
+Remove a rule from the dependency matcher. A `KeyError` is raised if the match
+ID does not exist.
+
+> #### Example
+>
+> ```python
+> matcher.add("FOUNDED", patterns)
+> assert "FOUNDED" in matcher
+> matcher.remove("FOUNDED")
+> assert "FOUNDED" not in matcher
+> ```
+
+| Name | Description |
+| ----- | --------------------------------- |
+| `key` | The ID of the match rule. ~~str~~ |
diff --git a/website/docs/images/dep-match-diagram.svg b/website/docs/images/dep-match-diagram.svg
new file mode 100644
index 000000000..676be4137
--- /dev/null
+++ b/website/docs/images/dep-match-diagram.svg
@@ -0,0 +1,39 @@
+
diff --git a/website/docs/images/displacy-dep-founded.html b/website/docs/images/displacy-dep-founded.html
new file mode 100644
index 000000000..e22984ee1
--- /dev/null
+++ b/website/docs/images/displacy-dep-founded.html
@@ -0,0 +1,58 @@
+
diff --git a/website/docs/usage/rule-based-matching.md b/website/docs/usage/rule-based-matching.md
index fb54c9936..01d60ddb8 100644
--- a/website/docs/usage/rule-based-matching.md
+++ b/website/docs/usage/rule-based-matching.md
@@ -4,6 +4,7 @@ teaser: Find phrases and tokens, and match entities
menu:
- ['Token Matcher', 'matcher']
- ['Phrase Matcher', 'phrasematcher']
+ - ['Dependency Matcher', 'dependencymatcher']
- ['Entity Ruler', 'entityruler']
- ['Models & Rules', 'models-rules']
---
@@ -939,10 +940,10 @@ object patterns as efficiently as possible and without running any of the other
pipeline components. If the token attribute you want to match on is set by a
pipeline component, **make sure that the pipeline component runs** when you
create the pattern. For example, to match on `POS` or `LEMMA`, the pattern `Doc`
-objects need to have part-of-speech tags set by the `tagger`. You can either
-call the `nlp` object on your pattern texts instead of `nlp.make_doc`, or use
-[`nlp.select_pipes`](/api/language#select_pipes) to disable components
-selectively.
+objects need to have part-of-speech tags set by the `tagger` or `morphologizer`.
+You can either call the `nlp` object on your pattern texts instead of
+`nlp.make_doc`, or use [`nlp.select_pipes`](/api/language#select_pipes) to
+disable components selectively.
@@ -973,10 +974,287 @@ to match phrases with the same sequence of punctuation and non-punctuation
tokens as the pattern. But this can easily get confusing and doesn't have much
of an advantage over writing one or two token patterns.
+## Dependency Matcher {#dependencymatcher new="3" model="parser"}
+
+The [`DependencyMatcher`](/api/dependencymatcher) lets you match patterns within
+the dependency parse using
+[Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html)
+operators. It requires a model containing a parser such as the
+[`DependencyParser`](/api/dependencyparser). Instead of defining a list of
+adjacent tokens as in `Matcher` patterns, the `DependencyMatcher` patterns match
+tokens in the dependency parse and specify the relations between them.
+
+> ```python
+> ### Example
+> from spacy.matcher import DependencyMatcher
+>
+> # "[subject] ... initially founded"
+> pattern = [
+> # anchor token: founded
+> {
+> "RIGHT_ID": "founded",
+> "RIGHT_ATTRS": {"ORTH": "founded"}
+> },
+> # founded -> subject
+> {
+> "LEFT_ID": "founded",
+> "REL_OP": ">",
+> "RIGHT_ID": "subject",
+> "RIGHT_ATTRS": {"DEP": "nsubj"}
+> },
+> # "founded" follows "initially"
+> {
+> "LEFT_ID": "founded",
+> "REL_OP": ";",
+> "RIGHT_ID": "initially",
+> "RIGHT_ATTRS": {"ORTH": "initially"}
+> }
+> ]
+>
+> matcher = DependencyMatcher(nlp.vocab)
+> matcher.add("FOUNDED", [pattern])
+> matches = matcher(doc)
+> ```
+
+A pattern added to the dependency matcher consists of a **list of
+dictionaries**, with each dictionary describing a **token to match** and its
+**relation to an existing token** in the pattern. Except for the first
+dictionary, which defines an anchor token using only `RIGHT_ID` and
+`RIGHT_ATTRS`, each dictionary should have the following keys:
+
+| Name | Description |
+| ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `LEFT_ID` | The name of the left-hand node in the relation, which has been defined in an earlier node. ~~str~~ |
+| `REL_OP` | An operator that describes how the two nodes are related. ~~str~~ |
+| `RIGHT_ID` | A unique name for the right-hand node in the relation. ~~str~~ |
+| `RIGHT_ATTRS` | The token attributes to match for the right-hand node in the same format as patterns provided to the regular token-based [`Matcher`](/api/matcher). ~~Dict[str, Any]~~ |
+
+Each additional token added to the pattern is linked to an existing token
+`LEFT_ID` by the relation `REL_OP`. The new token is given the name `RIGHT_ID`
+and described by the attributes `RIGHT_ATTRS`.
+
+
+
+Because the unique token **names** in `LEFT_ID` and `RIGHT_ID` are used to
+identify tokens, the order of the dicts in the patterns is important: a token
+name needs to be defined as `RIGHT_ID` in one dict in the pattern **before** it
+can be used as `LEFT_ID` in another dict.
+
+
+
+### Dependency matcher operators {#dependencymatcher-operators}
+
+The following operators are supported by the `DependencyMatcher`, most of which
+come directly from
+[Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html):
+
+| Symbol | Description |
+| --------- | -------------------------------------------------------------------------------------------------------------------- |
+| `A < B` | `A` is the immediate dependent of `B`. |
+| `A > B` | `A` is the immediate head of `B`. |
+| `A << B` | `A` is the dependent in a chain to `B` following dep → head paths. |
+| `A >> B` | `A` is the head in a chain to `B` following head → dep paths. |
+| `A . B` | `A` immediately precedes `B`, i.e. `A.i == B.i - 1`, and both are within the same dependency tree. |
+| `A .* B` | `A` precedes `B`, i.e. `A.i < B.i`, and both are within the same dependency tree _(not in Semgrex)_. |
+| `A ; B` | `A` immediately follows `B`, i.e. `A.i == B.i + 1`, and both are within the same dependency tree _(not in Semgrex)_. |
+| `A ;* B` | `A` follows `B`, i.e. `A.i > B.i`, and both are within the same dependency tree _(not in Semgrex)_. |
+| `A $+ B` | `B` is a right immediate sibling of `A`, i.e. `A` and `B` have the same parent and `A.i == B.i - 1`. |
+| `A $- B` | `B` is a left immediate sibling of `A`, i.e. `A` and `B` have the same parent and `A.i == B.i + 1`. |
+| `A $++ B` | `B` is a right sibling of `A`, i.e. `A` and `B` have the same parent and `A.i < B.i`. |
+| `A $-- B` | `B` is a left sibling of `A`, i.e. `A` and `B` have the same parent and `A.i > B.i`. |
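+
+The following is a minimal sketch of the sibling operators (it assumes the
+`en_core_web_sm` pipeline is installed; the match ID `SIBLINGS` and the node
+names `left_mod` and `right_mod` are arbitrary). The pattern matches "brown"
+as the immediate right sibling of "quick":
+
+```python
+import spacy
+from spacy.matcher import DependencyMatcher
+
+nlp = spacy.load("en_core_web_sm")
+matcher = DependencyMatcher(nlp.vocab)
+pattern = [
+    {"RIGHT_ID": "left_mod", "RIGHT_ATTRS": {"ORTH": "quick"}},
+    # $+ : the new token shares its head with "quick" and directly follows it
+    {"LEFT_ID": "left_mod", "REL_OP": "$+", "RIGHT_ID": "right_mod",
+     "RIGHT_ATTRS": {"ORTH": "brown"}},
+]
+matcher.add("SIBLINGS", [pattern])
+doc = nlp("The quick brown fox jumped over the lazy dog.")
+# with a typical parse, "quick" and "brown" are both attached to "fox",
+# so this prints a single match: [(<match_id>, [1, 2])]
+print(matcher(doc))
+```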
+
+### Designing dependency matcher patterns {#dependencymatcher-patterns}
+
+Let's say we want to find sentences describing who founded what kind of company:
+
+- _Smith founded a healthcare company in 2005._
+- _Williams initially founded an insurance company in 1987._
+- _Lee, an experienced CEO, has founded two AI startups._
+
+The dependency parse for "Smith founded a healthcare company" shows types of
+relations and tokens we want to match:
+
+> #### Visualizing the parse
+>
+> The [`displacy` visualizer](/usage/visualizer) lets you render `Doc` objects
+> and their dependency parse and part-of-speech tags:
+>
+> ```python
+> import spacy
+> from spacy import displacy
+>
+> nlp = spacy.load("en_core_web_sm")
+> doc = nlp("Smith founded a healthcare company")
+> displacy.serve(doc)
+> ```
+
+import DisplaCyDepFoundedHtml from 'images/displacy-dep-founded.html'
+
+
+
+The relations we're interested in are:
+
+- the founder is the **subject** (`nsubj`) of the token with the text `founded`
+- the company is the **object** (`dobj`) of `founded`
+- the kind of company may be an **adjective** (`amod`, not shown above) or a
+ **compound** (`compound`)
+
+The first step is to pick an **anchor token** for the pattern. Since it's the
+root of the dependency parse, `founded` is a good choice here. It is often
+easier to construct patterns when all dependency relation operators point from
+the head to the children. In this example, we'll only use `>`, which connects a
+head to an immediate dependent as `head > child`.
+
+The simplest dependency matcher pattern will identify and name a single token in
+the tree:
+
+```python
+### {executable="true"}
+import spacy
+from spacy.matcher import DependencyMatcher
+
+nlp = spacy.load("en_core_web_sm")
+matcher = DependencyMatcher(nlp.vocab)
+pattern = [
+ {
+ "RIGHT_ID": "anchor_founded", # unique name
+ "RIGHT_ATTRS": {"ORTH": "founded"} # token pattern for "founded"
+ }
+]
+matcher.add("FOUNDED", [pattern])
+doc = nlp("Smith founded two companies.")
+matches = matcher(doc)
+print(matches) # [(4851363122962674176, [1])]
+```
+
+Now that we have a named anchor token (`anchor_founded`), we can add the founder
+as the immediate dependent (`>`) of `founded` with the dependency label `nsubj`:
+
+```python
+### Step 1 {highlight="8,10"}
+pattern = [
+ {
+ "RIGHT_ID": "anchor_founded",
+ "RIGHT_ATTRS": {"ORTH": "founded"}
+ },
+ {
+ "LEFT_ID": "anchor_founded",
+ "REL_OP": ">",
+ "RIGHT_ID": "subject",
+ "RIGHT_ATTRS": {"DEP": "nsubj"},
+ }
+ # ...
+]
+```
+
+The direct object (`dobj`) is added in the same way:
+
+```python
+### Step 2
+pattern = [
+ #...
+ {
+ "LEFT_ID": "anchor_founded",
+ "REL_OP": ">",
+ "RIGHT_ID": "founded_object",
+ "RIGHT_ATTRS": {"DEP": "dobj"},
+ }
+ # ...
+]
+```
+
+When the subject and object tokens are added, they are required to have names
+under the key `RIGHT_ID`, which can be any unique string, e.g.
+`founded_subject`. These names can then be used as `LEFT_ID` to **link new
+tokens into the pattern**. For the final part of our pattern, we'll specify that
+the token `founded_object` should have a modifier with the dependency relation
+`amod` or `compound`:
+
+```python
+### Step 3 {highlight="7"}
+pattern = [
+ # ...
+ {
+ "LEFT_ID": "founded_object",
+ "REL_OP": ">",
+ "RIGHT_ID": "founded_object_modifier",
+ "RIGHT_ATTRS": {"DEP": {"IN": ["amod", "compound"]}},
+ }
+]
+```
+
+You can picture the process of creating a dependency matcher pattern as defining
+an anchor token on the left and building up the pattern by linking tokens
+one-by-one on the right using relation operators. To create a valid pattern,
+each new token needs to be linked to an existing token on its left. As for
+`founded` in this example, a token may be linked to more than one token on its
+right:
+
+![Dependency matcher pattern](../images/dep-match-diagram.svg)
+
+The full pattern comes together as shown in the example below:
+
+```python
+### {executable="true"}
+import spacy
+from spacy.matcher import DependencyMatcher
+
+nlp = spacy.load("en_core_web_sm")
+matcher = DependencyMatcher(nlp.vocab)
+
+pattern = [
+ {
+ "RIGHT_ID": "anchor_founded",
+ "RIGHT_ATTRS": {"ORTH": "founded"}
+ },
+ {
+ "LEFT_ID": "anchor_founded",
+ "REL_OP": ">",
+ "RIGHT_ID": "subject",
+ "RIGHT_ATTRS": {"DEP": "nsubj"},
+ },
+ {
+ "LEFT_ID": "anchor_founded",
+ "REL_OP": ">",
+ "RIGHT_ID": "founded_object",
+ "RIGHT_ATTRS": {"DEP": "dobj"},
+ },
+ {
+ "LEFT_ID": "founded_object",
+ "REL_OP": ">",
+ "RIGHT_ID": "founded_object_modifier",
+ "RIGHT_ATTRS": {"DEP": {"IN": ["amod", "compound"]}},
+ }
+]
+
+matcher.add("FOUNDED", [pattern])
+doc = nlp("Lee, an experienced CEO, has founded two AI startups.")
+matches = matcher(doc)
+
+print(matches) # [(4851363122962674176, [6, 0, 10, 9])]
+# Each token_id corresponds to one pattern dict
+match_id, token_ids = matches[0]
+for i in range(len(token_ids)):
+ print(pattern[i]["RIGHT_ID"] + ":", doc[token_ids[i]].text)
+```
+
+
+
+The dependency matcher may be slow when token patterns can potentially match
+many tokens in the sentence or when relation operators allow longer paths in the
+dependency parse, e.g. `<<`, `>>`, `.*` and `;*`.
+
+To improve the matcher speed, try to make your token patterns and operators as
+specific as possible. For example, use `>` instead of `>>` where you can, and use
+token patterns that include dependency labels and other token attributes instead
+of patterns such as `{}` that match any token in the sentence.
+
+
+
## Rule-based entity recognition {#entityruler new="2.1"}
-The [`EntityRuler`](/api/entityruler) is an exciting new component that lets you
-add named entities based on pattern dictionaries, and makes it easy to combine
+The [`EntityRuler`](/api/entityruler) is a component that lets you add named
+entities based on pattern dictionaries, which makes it easy to combine
rule-based and statistical named entity recognition for even more powerful
pipelines.
diff --git a/website/docs/usage/v3.md b/website/docs/usage/v3.md
index 45ed7b0c8..e5228ab21 100644
--- a/website/docs/usage/v3.md
+++ b/website/docs/usage/v3.md
@@ -26,6 +26,7 @@ menu:
- [End-to-end project workflows](#features-projects)
- [New built-in components](#features-pipeline-components)
- [New custom component API](#features-components)
+- [Dependency matching](#features-dep-matcher)
- [Python type hints](#features-types)
- [New methods & attributes](#new-methods)
- [New & updated documentation](#new-docs)
@@ -201,6 +202,34 @@ aren't set.
+### Dependency matching {#features-dep-matcher}
+
+
+
+> #### Example
+>
+> ```python
+> # TODO: example
+> ```
+
+The [`DependencyMatcher`](/api/dependencymatcher) lets you match patterns within
+the dependency parse using
+[Semgrex](https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html)
+operators. It follows the same API as the token-based [`Matcher`](/api/matcher).
+A pattern added to the dependency matcher consists of a **list of
+dictionaries**, with each dictionary describing a **token to match** and its
+**relation to an existing token** in the pattern.
+
+
+
+- **Usage:**
+ [Dependency matching](/usage/rule-based-matching#dependencymatcher),
+- **API:** [`DependencyMatcher`](/api/dependencymatcher),
+- **Implementation:**
+ [`spacy/matcher/dependencymatcher.pyx`](https://github.com/explosion/spaCy/tree/develop/spacy/matcher/dependencymatcher.pyx)
+
+
+
### Type hints and type-based data validation {#features-types}
> #### Example
@@ -313,7 +342,8 @@ format for documenting argument and return types.
[`Transformer`](/api/transformer), [`Lemmatizer`](/api/lemmatizer),
[`Morphologizer`](/api/morphologizer),
[`AttributeRuler`](/api/attributeruler),
- [`SentenceRecognizer`](/api/sentencerecognizer), [`Pipe`](/api/pipe),
+ [`SentenceRecognizer`](/api/sentencerecognizer),
+  [`DependencyMatcher`](/api/dependencymatcher), [`Pipe`](/api/pipe),
[`Corpus`](/api/corpus)