mirror of https://github.com/explosion/spaCy.git
💫 Allow matching non-ORTH attributes in PhraseMatcher (#2925)
* Allow matching non-orth attributes in PhraseMatcher (see #1971) Usage: PhraseMatcher(nlp.vocab, attr='POS') * Allow attr argument to be int * Fix formatting * Fix typo
This commit is contained in:
parent
7ed9124a45
commit
e89708c3eb
|
@ -12,7 +12,7 @@ from .lexeme cimport attr_id_t
|
|||
from .vocab cimport Vocab
|
||||
from .tokens.doc cimport Doc
|
||||
from .tokens.doc cimport get_token_attr
|
||||
from .attrs cimport ID, attr_id_t, NULL_ATTR
|
||||
from .attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
|
||||
from .errors import Errors, TempErrors, Warnings, deprecation_warning
|
||||
|
||||
from .attrs import IDS
|
||||
|
@ -546,16 +546,21 @@ cdef class PhraseMatcher:
|
|||
cdef Matcher matcher
|
||||
cdef PreshMap phrase_ids
|
||||
cdef int max_length
|
||||
cdef attr_id_t attr
|
||||
cdef public object _callbacks
|
||||
cdef public object _patterns
|
||||
|
||||
def __init__(self, Vocab vocab, max_length=0):
|
||||
def __init__(self, Vocab vocab, max_length=0, attr='ORTH'):
|
||||
if max_length != 0:
|
||||
deprecation_warning(Warnings.W010)
|
||||
self.mem = Pool()
|
||||
self.max_length = max_length
|
||||
self.vocab = vocab
|
||||
self.matcher = Matcher(self.vocab)
|
||||
if isinstance(attr, long):
|
||||
self.attr = attr
|
||||
else:
|
||||
self.attr = self.vocab.strings[attr]
|
||||
self.phrase_ids = PreshMap()
|
||||
abstract_patterns = [
|
||||
[{U_ENT: True}],
|
||||
|
@ -609,7 +614,8 @@ cdef class PhraseMatcher:
|
|||
tags = get_bilou(length)
|
||||
phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
|
||||
for i, tag in enumerate(tags):
|
||||
lexeme = self.vocab[doc.c[i].lex.orth]
|
||||
attr_value = self.get_lex_value(doc, i)
|
||||
lexeme = self.vocab[attr_value]
|
||||
lexeme.set_flag(tag, True)
|
||||
phrase_key[i] = lexeme.orth
|
||||
phrase_hash = hash64(phrase_key,
|
||||
|
@ -625,8 +631,16 @@ cdef class PhraseMatcher:
|
|||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||
"""
|
||||
matches = []
|
||||
for _, start, end in self.matcher(doc):
|
||||
ent_id = self.accept_match(doc, start, end)
|
||||
if self.attr == ORTH:
|
||||
match_doc = doc
|
||||
else:
|
||||
# If we're not matching on the ORTH, match_doc will be a Doc whose
|
||||
# token.orth values are the attribute values we're matching on,
|
||||
# e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
|
||||
words = [self.get_lex_value(doc, i) for i in range(len(doc))]
|
||||
match_doc = Doc(self.vocab, words=words)
|
||||
for _, start, end in self.matcher(match_doc):
|
||||
ent_id = self.accept_match(match_doc, start, end)
|
||||
if ent_id is not None:
|
||||
matches.append((ent_id, start, end))
|
||||
for i, (ent_id, start, end) in enumerate(matches):
|
||||
|
@ -680,6 +694,23 @@ cdef class PhraseMatcher:
|
|||
else:
|
||||
return ent_id
|
||||
|
||||
def get_lex_value(self, Doc doc, int i):
|
||||
if self.attr == ORTH:
|
||||
# Return the regular orth value of the lexeme
|
||||
return doc.c[i].lex.orth
|
||||
# Get the attribute value instead, e.g. token.pos
|
||||
attr_value = get_token_attr(&doc.c[i], self.attr)
|
||||
if attr_value in (0, 1):
|
||||
# Value is boolean, convert to string
|
||||
string_attr_value = str(attr_value)
|
||||
else:
|
||||
string_attr_value = self.vocab.strings[attr_value]
|
||||
string_attr_name = self.vocab.strings[self.attr]
|
||||
# Concatenate the attr name and value to not pollute lexeme space
|
||||
# e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
|
||||
# create false positive matches
|
||||
return 'matcher:{}-{}'.format(string_attr_name, string_attr_value)
|
||||
|
||||
|
||||
cdef class DependencyTreeMatcher:
|
||||
"""Match dependency parse tree based on pattern rules."""
|
||||
|
|
|
@ -5,6 +5,8 @@ import pytest
|
|||
from spacy.matcher import PhraseMatcher
|
||||
from spacy.tokens import Doc
|
||||
|
||||
from ..util import get_doc
|
||||
|
||||
|
||||
def test_matcher_phrase_matcher(en_vocab):
|
||||
doc = Doc(en_vocab, words=["Google", "Now"])
|
||||
|
@ -28,3 +30,53 @@ def test_phrase_matcher_contains(en_vocab):
|
|||
matcher.add('TEST', None, Doc(en_vocab, words=['test']))
|
||||
assert 'TEST' in matcher
|
||||
assert 'TEST2' not in matcher
|
||||
|
||||
|
||||
def test_phrase_matcher_string_attrs(en_vocab):
|
||||
words1 = ['I', 'like', 'cats']
|
||||
pos1 = ['PRON', 'VERB', 'NOUN']
|
||||
words2 = ['Yes', ',', 'you', 'hate', 'dogs', 'very', 'much']
|
||||
pos2 = ['INTJ', 'PUNCT', 'PRON', 'VERB', 'NOUN', 'ADV', 'ADV']
|
||||
pattern = get_doc(en_vocab, words=words1, pos=pos1)
|
||||
matcher = PhraseMatcher(en_vocab, attr='POS')
|
||||
matcher.add('TEST', None, pattern)
|
||||
doc = get_doc(en_vocab, words=words2, pos=pos2)
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 1
|
||||
match_id, start, end = matches[0]
|
||||
assert match_id == en_vocab.strings['TEST']
|
||||
assert start == 2
|
||||
assert end == 5
|
||||
|
||||
|
||||
def test_phrase_matcher_string_attrs_negative(en_vocab):
|
||||
"""Test that token with the control codes as ORTH are *not* matched."""
|
||||
words1 = ['I', 'like', 'cats']
|
||||
pos1 = ['PRON', 'VERB', 'NOUN']
|
||||
words2 = ['matcher:POS-PRON', 'matcher:POS-VERB', 'matcher:POS-NOUN']
|
||||
pos2 = ['X', 'X', 'X']
|
||||
pattern = get_doc(en_vocab, words=words1, pos=pos1)
|
||||
matcher = PhraseMatcher(en_vocab, attr='POS')
|
||||
matcher.add('TEST', None, pattern)
|
||||
doc = get_doc(en_vocab, words=words2, pos=pos2)
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 0
|
||||
|
||||
|
||||
def test_phrase_matcher_bool_attrs(en_vocab):
|
||||
words1 = ['Hello', 'world', '!']
|
||||
words2 = ['No', 'problem', ',', 'he', 'said', '.']
|
||||
pattern = Doc(en_vocab, words=words1)
|
||||
matcher = PhraseMatcher(en_vocab, attr='IS_PUNCT')
|
||||
matcher.add('TEST', None, pattern)
|
||||
doc = Doc(en_vocab, words=words2)
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 2
|
||||
match_id1, start1, end1 = matches[0]
|
||||
match_id2, start2, end2 = matches[1]
|
||||
assert match_id1 == en_vocab.strings['TEST']
|
||||
assert match_id2 == en_vocab.strings['TEST']
|
||||
assert start1 == 0
|
||||
assert end1 == 3
|
||||
assert start2 == 3
|
||||
assert end2 == 6
|
||||
|
|
Loading…
Reference in New Issue