💫 Allow matching non-ORTH attributes in PhraseMatcher (#2925)

* Allow matching non-orth attributes in PhraseMatcher (see #1971)

Usage: PhraseMatcher(nlp.vocab, attr='POS')

* Allow attr argument to be int

* Fix formatting

* Fix typo
This commit is contained in:
Ines Montani 2018-11-15 03:00:58 +01:00 committed by Matthew Honnibal
parent 7ed9124a45
commit e89708c3eb
2 changed files with 88 additions and 5 deletions

View File

@ -12,7 +12,7 @@ from .lexeme cimport attr_id_t
from .vocab cimport Vocab
from .tokens.doc cimport Doc
from .tokens.doc cimport get_token_attr
from .attrs cimport ID, attr_id_t, NULL_ATTR
from .attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
from .errors import Errors, TempErrors, Warnings, deprecation_warning
from .attrs import IDS
@ -546,16 +546,21 @@ cdef class PhraseMatcher:
cdef Matcher matcher
cdef PreshMap phrase_ids
cdef int max_length
cdef attr_id_t attr
cdef public object _callbacks
cdef public object _patterns
def __init__(self, Vocab vocab, max_length=0):
def __init__(self, Vocab vocab, max_length=0, attr='ORTH'):
if max_length != 0:
deprecation_warning(Warnings.W010)
self.mem = Pool()
self.max_length = max_length
self.vocab = vocab
self.matcher = Matcher(self.vocab)
if isinstance(attr, long):
self.attr = attr
else:
self.attr = self.vocab.strings[attr]
self.phrase_ids = PreshMap()
abstract_patterns = [
[{U_ENT: True}],
@ -609,7 +614,8 @@ cdef class PhraseMatcher:
tags = get_bilou(length)
phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
for i, tag in enumerate(tags):
lexeme = self.vocab[doc.c[i].lex.orth]
attr_value = self.get_lex_value(doc, i)
lexeme = self.vocab[attr_value]
lexeme.set_flag(tag, True)
phrase_key[i] = lexeme.orth
phrase_hash = hash64(phrase_key,
@ -625,8 +631,16 @@ cdef class PhraseMatcher:
`doc[start:end]`. The `label_id` and `key` are both integers.
"""
matches = []
for _, start, end in self.matcher(doc):
ent_id = self.accept_match(doc, start, end)
if self.attr == ORTH:
match_doc = doc
else:
# If we're not matching on the ORTH, match_doc will be a Doc whose
# token.orth values are the attribute values we're matching on,
# e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
words = [self.get_lex_value(doc, i) for i in range(len(doc))]
match_doc = Doc(self.vocab, words=words)
for _, start, end in self.matcher(match_doc):
ent_id = self.accept_match(match_doc, start, end)
if ent_id is not None:
matches.append((ent_id, start, end))
for i, (ent_id, start, end) in enumerate(matches):
@ -680,6 +694,23 @@ cdef class PhraseMatcher:
else:
return ent_id
def get_lex_value(self, Doc doc, int i):
if self.attr == ORTH:
# Return the regular orth value of the lexeme
return doc.c[i].lex.orth
# Get the attribute value instead, e.g. token.pos
attr_value = get_token_attr(&doc.c[i], self.attr)
if attr_value in (0, 1):
# Value is boolean, convert to string
string_attr_value = str(attr_value)
else:
string_attr_value = self.vocab.strings[attr_value]
string_attr_name = self.vocab.strings[self.attr]
# Concatenate the attr name and value to not pollute lexeme space
# e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
# create false positive matches
return 'matcher:{}-{}'.format(string_attr_name, string_attr_value)
cdef class DependencyTreeMatcher:
"""Match dependency parse tree based on pattern rules."""

View File

@ -5,6 +5,8 @@ import pytest
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc
from ..util import get_doc
def test_matcher_phrase_matcher(en_vocab):
doc = Doc(en_vocab, words=["Google", "Now"])
@ -28,3 +30,53 @@ def test_phrase_matcher_contains(en_vocab):
matcher.add('TEST', None, Doc(en_vocab, words=['test']))
assert 'TEST' in matcher
assert 'TEST2' not in matcher
def test_phrase_matcher_string_attrs(en_vocab):
words1 = ['I', 'like', 'cats']
pos1 = ['PRON', 'VERB', 'NOUN']
words2 = ['Yes', ',', 'you', 'hate', 'dogs', 'very', 'much']
pos2 = ['INTJ', 'PUNCT', 'PRON', 'VERB', 'NOUN', 'ADV', 'ADV']
pattern = get_doc(en_vocab, words=words1, pos=pos1)
matcher = PhraseMatcher(en_vocab, attr='POS')
matcher.add('TEST', None, pattern)
doc = get_doc(en_vocab, words=words2, pos=pos2)
matches = matcher(doc)
assert len(matches) == 1
match_id, start, end = matches[0]
assert match_id == en_vocab.strings['TEST']
assert start == 2
assert end == 5
def test_phrase_matcher_string_attrs_negative(en_vocab):
"""Test that token with the control codes as ORTH are *not* matched."""
words1 = ['I', 'like', 'cats']
pos1 = ['PRON', 'VERB', 'NOUN']
words2 = ['matcher:POS-PRON', 'matcher:POS-VERB', 'matcher:POS-NOUN']
pos2 = ['X', 'X', 'X']
pattern = get_doc(en_vocab, words=words1, pos=pos1)
matcher = PhraseMatcher(en_vocab, attr='POS')
matcher.add('TEST', None, pattern)
doc = get_doc(en_vocab, words=words2, pos=pos2)
matches = matcher(doc)
assert len(matches) == 0
def test_phrase_matcher_bool_attrs(en_vocab):
words1 = ['Hello', 'world', '!']
words2 = ['No', 'problem', ',', 'he', 'said', '.']
pattern = Doc(en_vocab, words=words1)
matcher = PhraseMatcher(en_vocab, attr='IS_PUNCT')
matcher.add('TEST', None, pattern)
doc = Doc(en_vocab, words=words2)
matches = matcher(doc)
assert len(matches) == 2
match_id1, start1, end1 = matches[0]
match_id2, start2, end2 = matches[1]
assert match_id1 == en_vocab.strings['TEST']
assert match_id2 == en_vocab.strings['TEST']
assert start1 == 0
assert end1 == 3
assert start2 == 3
assert end2 == 6