Set up dependency tree pattern matching skeleton (#2732)

This commit is contained in:
Suraj Rajan 2018-09-27 16:57:18 +05:30 committed by Matthew Honnibal
parent 8809dc4514
commit bbdc6456c6
2 changed files with 287 additions and 8 deletions

View File

@ -48,6 +48,7 @@ from .attrs import FLAG37 as L8_ENT
from .attrs import FLAG36 as L9_ENT
from .attrs import FLAG35 as L10_ENT
DELIMITER = '||'
cpdef enum quantifier_t:
_META
@ -66,7 +67,7 @@ cdef enum action_t:
ACCEPT_PREV
PANIC
# A "match expression" conists of one or more token patterns
# A "match expression" consists of one or more token patterns
# Each token pattern consists of a quantifier and 0+ (attr, value) pairs.
# A state is an (int, pattern pointer) pair, where the int is the start
# position, and the pattern pointer shows where we're up to
@ -76,16 +77,16 @@ cdef struct AttrValueC:
attr_id_t attr
attr_t value
cdef struct TokenPatternC:
AttrValueC* attrs
int32_t nr_attr
quantifier_t quantifier
ctypedef TokenPatternC* TokenPatternC_ptr
ctypedef pair[int, TokenPatternC_ptr] StateC
DEF PADDING = 5
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
object token_specs) except NULL:
@ -105,7 +106,6 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
pattern[i].nr_attr = 0
return pattern
cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0:
while pattern.nr_attr != 0:
pattern += 1
@ -262,7 +262,7 @@ cdef class Matcher:
key (unicode): The match ID.
on_match (callable): Callback executed on match.
*patterns (list): List of token descritions.
*patterns (list): List of token descriptions.
"""
for pattern in patterns:
if len(pattern) == 0:
@ -526,6 +526,7 @@ cdef class PhraseMatcher:
self.phrase_ids.set(phrase_hash, <void*>ent_id)
def __call__(self, Doc doc):
"""Find all sequences matching the supplied patterns on the `Doc`.
doc (Doc): The document to match over.
@ -573,3 +574,236 @@ cdef class PhraseMatcher:
return None
else:
return ent_id
cdef class DependencyTreeMatcher:
"""Match dependency parse tree based on pattern rules."""
cdef Pool mem
cdef readonly Vocab vocab
cdef readonly Matcher token_matcher
cdef public object _patterns
cdef public object _keys_to_token
cdef public object _root
cdef public object _entities
cdef public object _callbacks
cdef public object _nodes
cdef public object _tree
def __init__(self, vocab):
"""Create the DependencyTreeMatcher.
vocab (Vocab): The vocabulary object, which must be shared with the
documents the matcher will operate on.
RETURNS (DependencyTreeMatcher): The newly constructed object.
"""
size = 20
self.token_matcher = Matcher(vocab)
self._keys_to_token = {}
self._patterns = {}
self._root = {}
self._nodes = {}
self._tree = {}
self._entities = {}
self._callbacks = {}
self.vocab = vocab
self.mem = Pool()
def __reduce__(self):
data = (self.vocab, self._patterns,self._tree, self._callbacks)
return (unpickle_matcher, data, None, None)
def __len__(self):
"""Get the number of rules, which are edges ,added to the dependency tree matcher.
RETURNS (int): The number of rules.
"""
return len(self._patterns)
def __contains__(self, key):
"""Check whether the matcher contains rules for a match ID.
key (unicode): The match ID.
RETURNS (bool): Whether the matcher contains rules for this match ID.
"""
return self._normalize_key(key) in self._patterns
def add(self, key, on_match, *patterns):
# TODO : validations
# 1. check if input pattern is connected
# 2. check if pattern format is correct
# 3. check if atleast one root node is present
# 4. check if node names are not repeated
# 5. check if each node has only one head
for pattern in patterns:
if len(pattern) == 0:
raise ValueError(Errors.E012.format(key=key))
key = self._normalize_key(key)
_patterns = []
for pattern in patterns:
token_patterns = []
for i in range(len(pattern)):
token_pattern = [pattern[i]['PATTERN']]
token_patterns.append(token_pattern)
# self.patterns.append(token_patterns)
_patterns.append(token_patterns)
self._patterns.setdefault(key, [])
self._callbacks[key] = on_match
self._patterns[key].extend(_patterns)
# Add each node pattern of all the input patterns individually to the matcher.
# This enables only a single instance of Matcher to be used.
# Multiple adds are required to track each node pattern.
_keys_to_token_list = []
for i in range(len(_patterns)):
_keys_to_token = {}
# TODO : Better ways to hash edges in pattern?
for j in range(len(_patterns[i])):
k = self._normalize_key(unicode(key)+DELIMITER+unicode(i)+DELIMITER+unicode(j))
self.token_matcher.add(k,None,_patterns[i][j])
_keys_to_token[k] = j
_keys_to_token_list.append(_keys_to_token)
self._keys_to_token.setdefault(key, [])
self._keys_to_token[key].extend(_keys_to_token_list)
_nodes_list = []
for pattern in patterns:
nodes = {}
for i in range(len(pattern)):
nodes[pattern[i]['SPEC']['NODE_NAME']]=i
_nodes_list.append(nodes)
self._nodes.setdefault(key, [])
self._nodes[key].extend(_nodes_list)
# Create an object tree to traverse later on.
# This datastructure enable easy tree pattern match.
# Doc-Token based tree cannot be reused since it is memory heavy and tightly coupled with doc
self.retrieve_tree(patterns,_nodes_list,key)
def retrieve_tree(self,patterns,_nodes_list,key):
_heads_list = []
_root_list = []
for i in range(len(patterns)):
heads = {}
root = -1
for j in range(len(patterns[i])):
token_pattern = patterns[i][j]
if('NBOR_RELOP' not in token_pattern['SPEC']):
heads[j] = j
root = j
else:
# TODO: Add semgrex rules
# 1. >
if(token_pattern['SPEC']['NBOR_RELOP'] == '>'):
heads[j] = _nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]
# 2. <
if(token_pattern['SPEC']['NBOR_RELOP'] == '<'):
heads[_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]] = j
_heads_list.append(heads)
_root_list.append(root)
_tree_list = []
for i in range(len(patterns)):
tree = {}
for j in range(len(patterns[i])):
if(j == _heads_list[i][j]):
continue
head = _heads_list[i][j]
if(head not in tree):
tree[head] = []
tree[head].append(j)
_tree_list.append(tree)
self._tree.setdefault(key, [])
self._tree[key].extend(_tree_list)
self._root.setdefault(key, [])
self._root[key].extend(_root_list)
def has_key(self, key):
"""Check whether the matcher has a rule with a given key.
key (string or int): The key to check.
RETURNS (bool): Whether the matcher has the rule.
"""
key = self._normalize_key(key)
return key in self._patterns
def get(self, key, default=None):
"""Retrieve the pattern stored for a key.
key (unicode or int): The key to retrieve.
RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
"""
key = self._normalize_key(key)
if key not in self._patterns:
return default
return (self._callbacks[key], self._patterns[key])
def __call__(self, Doc doc):
matched_trees = []
matches = self.token_matcher(doc)
for key in list(self._patterns.keys()):
_patterns_list = self._patterns[key]
_keys_to_token_list = self._keys_to_token[key]
_root_list = self._root[key]
_tree_list = self._tree[key]
_nodes_list = self._nodes[key]
length = len(_patterns_list)
for i in range(length):
_keys_to_token = _keys_to_token_list[i]
_root = _root_list[i]
_tree = _tree_list[i]
_nodes = _nodes_list[i]
id_to_position = {}
# This could be taken outside to improve running time..?
for match_id, start, end in matches:
if match_id in _keys_to_token:
if _keys_to_token[match_id] not in id_to_position:
id_to_position[_keys_to_token[match_id]] = []
id_to_position[_keys_to_token[match_id]].append(start)
length = len(_nodes)
if _root in id_to_position:
candidates = id_to_position[_root]
for candidate in candidates:
isVisited = {}
self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited)
# to check if the subtree pattern is completely identified
if(len(isVisited) == length):
matched_trees.append((key,list(isVisited)))
for i, (ent_id, nodes) in enumerate(matched_trees):
on_match = self._callbacks.get(ent_id)
if on_match is not None:
on_match(self, doc, i, matches)
return matched_trees
def dfs(self,candidate,root,tree,id_to_position,doc,isVisited):
if(root in id_to_position and candidate in id_to_position[root]):
# color the node since it is valid
isVisited[candidate] = True
candidate_children = doc[candidate].children
for candidate_child in candidate_children:
if root in tree:
for root_child in tree[root]:
self.dfs(candidate_child.i,root_child,tree,id_to_position,doc,isVisited)
def _normalize_key(self, key):
if isinstance(key, basestring):
return self.vocab.strings.add(key)
else:
return key

View File

@ -1,12 +1,14 @@
# coding: utf-8
from __future__ import unicode_literals
from ..matcher import Matcher, PhraseMatcher
from numpy import sort
from ..matcher import Matcher, PhraseMatcher, DependencyTreeMatcher
from .util import get_doc
from ..tokens import Doc
import pytest
import re
@pytest.fixture
def matcher(en_vocab):
@ -20,7 +22,6 @@ def matcher(en_vocab):
matcher.add(key, None, *patterns)
return matcher
def test_matcher_from_api_docs(en_vocab):
matcher = Matcher(en_vocab)
pattern = [{'ORTH': 'test'}]
@ -258,3 +259,47 @@ def test_matcher_end_zero_plus(matcher):
assert len(matcher(nlp(u'a b c'))) == 1
assert len(matcher(nlp(u'a b b c'))) == 1
assert len(matcher(nlp(u'a b b'))) == 1
@pytest.fixture
def text():
return u"The quick brown fox jumped over the lazy fox"
@pytest.fixture
def heads():
return [3,2,1,1,0,-1,2,1,-3]
@pytest.fixture
def deps():
return ['det', 'amod', 'amod', 'nsubj', 'prep', 'pobj', 'det', 'amod']
@pytest.fixture
def dependency_tree_matcher(en_vocab):
is_brown_yellow = lambda text: bool(re.compile(r'brown|yellow|over').match(text))
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
pattern1 = [
{'SPEC': {'NODE_NAME': 'fox'}, 'PATTERN': {'ORTH': 'fox'}},
{'SPEC': {'NODE_NAME': 'q', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'},'PATTERN': {'LOWER': u'quick'}},
{'SPEC': {'NODE_NAME': 'r', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'}, 'PATTERN': {IS_BROWN_YELLOW: True}}
]
pattern2 = [
{'SPEC': {'NODE_NAME': 'jumped'}, 'PATTERN': {'ORTH': 'jumped'}},
{'SPEC': {'NODE_NAME': 'fox', 'NBOR_RELOP': '>', 'NBOR_NAME': 'jumped'},'PATTERN': {'LOWER': u'fox'}},
{'SPEC': {'NODE_NAME': 'over', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'}, 'PATTERN': {IS_BROWN_YELLOW: True}}
]
matcher = DependencyTreeMatcher(en_vocab)
matcher.add('pattern1', None, pattern1)
matcher.add('pattern2', None, pattern2)
return matcher
def test_dependency_tree_matcher_compile(dependency_tree_matcher):
assert len(dependency_tree_matcher) == 2
def test_dependency_tree_matcher(dependency_tree_matcher,text,heads,deps):
doc = get_doc(dependency_tree_matcher.vocab,text.split(),heads=heads,deps=deps)
matches = dependency_tree_matcher(doc)
assert len(matches) == 2