mirror of https://github.com/explosion/spaCy.git
💫 Allow matching non-ORTH attributes in PhraseMatcher (#2925)
* Allow matching non-orth attributes in PhraseMatcher (see #1971). Usage: PhraseMatcher(nlp.vocab, attr='POS')
* Allow attr argument to be int
* Fix formatting
* Fix typo
parent 7ed9124a45
commit e89708c3eb
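For reference, a minimal usage sketch of the new attr keyword described in the commit message. It assumes a pipeline with a part-of-speech tagger (the model name en_core_web_sm is illustrative) and uses the spaCy 2.x matcher.add(key, callback, *docs) signature:

    import spacy
    from spacy.matcher import PhraseMatcher

    nlp = spacy.load('en_core_web_sm')              # any model with a POS tagger
    matcher = PhraseMatcher(nlp.vocab, attr='POS')  # match on POS, not ORTH
    pattern = nlp('I like cats')                    # POS sequence: PRON VERB NOUN
    matcher.add('POS_PHRASE', None, pattern)

    doc = nlp('you hate dogs')                      # same POS sequence
    matches = matcher(doc)                          # [(match_id, start, end)]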
--- a/spacy/matcher.pyx
+++ b/spacy/matcher.pyx
@@ -12,7 +12,7 @@ from .lexeme cimport attr_id_t
 from .vocab cimport Vocab
 from .tokens.doc cimport Doc
 from .tokens.doc cimport get_token_attr
-from .attrs cimport ID, attr_id_t, NULL_ATTR
+from .attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
 from .errors import Errors, TempErrors, Warnings, deprecation_warning

 from .attrs import IDS
@@ -546,16 +546,21 @@ cdef class PhraseMatcher:
     cdef Matcher matcher
     cdef PreshMap phrase_ids
     cdef int max_length
+    cdef attr_id_t attr
     cdef public object _callbacks
     cdef public object _patterns

-    def __init__(self, Vocab vocab, max_length=0):
+    def __init__(self, Vocab vocab, max_length=0, attr='ORTH'):
         if max_length != 0:
             deprecation_warning(Warnings.W010)
         self.mem = Pool()
         self.max_length = max_length
         self.vocab = vocab
         self.matcher = Matcher(self.vocab)
+        if isinstance(attr, long):
+            self.attr = attr
+        else:
+            self.attr = self.vocab.strings[attr]
         self.phrase_ids = PreshMap()
         abstract_patterns = [
             [{U_ENT: True}],
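The constructor change above normalizes attr to an integer attribute ID: integer inputs pass through, while string names are resolved via the vocab's string store (attribute names such as 'ORTH' and 'POS' are interned symbols, so the lookup yields their IDs). A standalone sketch of that normalization; the dict and its example IDs are made up for illustration:

    SYMBOLS = {'ORTH': 65, 'POS': 74}  # illustrative IDs, not the real values

    def normalize_attr(attr):
        if isinstance(attr, int):
            return attr            # already an attribute ID
        return SYMBOLS[attr]       # name -> ID, like vocab.strings[attr]

    assert normalize_attr('POS') == normalize_attr(74)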
@@ -609,7 +614,8 @@ cdef class PhraseMatcher:
         tags = get_bilou(length)
         phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
         for i, tag in enumerate(tags):
-            lexeme = self.vocab[doc.c[i].lex.orth]
+            attr_value = self.get_lex_value(doc, i)
+            lexeme = self.vocab[attr_value]
             lexeme.set_flag(tag, True)
             phrase_key[i] = lexeme.orth
         phrase_hash = hash64(phrase_key,
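For orientation: the add() logic in the hunk above marks each position of the phrase by setting a BILOU flag (B_ENT/I_ENT/L_ENT/U_ENT) on the corresponding lexeme, and records the key sequence that the real code then hashes with hash64. A rough Python rendering of that bookkeeping, with get_bilou assumed to return one flag ID per position:

    def add_phrase(vocab, words, get_bilou):
        # Rough sketch only; the real implementation is Cython.
        tags = get_bilou(len(words))        # one BILOU flag ID per position
        phrase_key = []
        for word, tag in zip(words, tags):
            lexeme = vocab[word]            # vocab lookup by string
            lexeme.set_flag(tag, True)      # mark the lexeme for the Matcher
            phrase_key.append(lexeme.orth)
        return tuple(phrase_key)            # real code hashes this sequence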
@@ -625,8 +631,16 @@ cdef class PhraseMatcher:
         `doc[start:end]`. The `label_id` and `key` are both integers.
         """
         matches = []
-        for _, start, end in self.matcher(doc):
-            ent_id = self.accept_match(doc, start, end)
+        if self.attr == ORTH:
+            match_doc = doc
+        else:
+            # If we're not matching on the ORTH, match_doc will be a Doc whose
+            # token.orth values are the attribute values we're matching on,
+            # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
+            words = [self.get_lex_value(doc, i) for i in range(len(doc))]
+            match_doc = Doc(self.vocab, words=words)
+        for _, start, end in self.matcher(match_doc):
+            ent_id = self.accept_match(match_doc, start, end)
             if ent_id is not None:
                 matches.append((ent_id, start, end))
         for i, (ent_id, start, end) in enumerate(matches):
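The comments in the hunk above describe the heart of the change: for non-ORTH attributes, matching runs over a shadow Doc whose words are the encoded attribute values. A self-contained sketch of that projection, where encode is a hypothetical stand-in for get_lex_value:

    from spacy.tokens import Doc

    def project(vocab, doc, encode):
        # Build the shadow Doc the matcher actually runs over.
        words = [encode(doc, i) for i in range(len(doc))]
        return Doc(vocab, words=words)

    # For attr='POS', encode(doc, i) yields strings like 'matcher:POS-VERB',
    # so the shadow Doc is effectively the sequence of POS codes.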
@@ -680,6 +694,23 @@ cdef class PhraseMatcher:
         else:
             return ent_id

+    def get_lex_value(self, Doc doc, int i):
+        if self.attr == ORTH:
+            # Return the regular orth value of the lexeme
+            return doc.c[i].lex.orth
+        # Get the attribute value instead, e.g. token.pos
+        attr_value = get_token_attr(&doc.c[i], self.attr)
+        if attr_value in (0, 1):
+            # Value is boolean, convert to string
+            string_attr_value = str(attr_value)
+        else:
+            string_attr_value = self.vocab.strings[attr_value]
+        string_attr_name = self.vocab.strings[self.attr]
+        # Concatenate the attr name and value to not pollute lexeme space
+        # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
+        # create false positive matches
+        return 'matcher:{}-{}'.format(string_attr_name, string_attr_value)
+
+
 cdef class DependencyTreeMatcher:
     """Match dependency parse tree based on pattern rules."""
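A standalone sketch of the key scheme get_lex_value implements above: boolean attribute values (stored as 0/1) are stringified, other values are resolved through the string store, and the result is namespaced with the attribute name so that a literal token like 'VERB' cannot collide with a POS value of VERB. The helper below is simplified to take string values directly:

    def make_key(attr_name, attr_value):
        # Simplified: the real code resolves non-boolean values via vocab.strings.
        if attr_value in (0, 1):
            attr_value = str(attr_value)    # boolean flag -> '0' / '1'
        return 'matcher:{}-{}'.format(attr_name, attr_value)

    assert make_key('POS', 'VERB') == 'matcher:POS-VERB'
    assert make_key('IS_PUNCT', 1) == 'matcher:IS_PUNCT-1'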
--- a/spacy/tests/matcher/test_phrase_matcher.py
+++ b/spacy/tests/matcher/test_phrase_matcher.py
@@ -5,6 +5,8 @@ import pytest
 from spacy.matcher import PhraseMatcher
 from spacy.tokens import Doc

+from ..util import get_doc
+

 def test_matcher_phrase_matcher(en_vocab):
     doc = Doc(en_vocab, words=["Google", "Now"])
@@ -28,3 +30,53 @@ def test_phrase_matcher_contains(en_vocab):
     matcher.add('TEST', None, Doc(en_vocab, words=['test']))
     assert 'TEST' in matcher
     assert 'TEST2' not in matcher
+
+
+def test_phrase_matcher_string_attrs(en_vocab):
+    words1 = ['I', 'like', 'cats']
+    pos1 = ['PRON', 'VERB', 'NOUN']
+    words2 = ['Yes', ',', 'you', 'hate', 'dogs', 'very', 'much']
+    pos2 = ['INTJ', 'PUNCT', 'PRON', 'VERB', 'NOUN', 'ADV', 'ADV']
+    pattern = get_doc(en_vocab, words=words1, pos=pos1)
+    matcher = PhraseMatcher(en_vocab, attr='POS')
+    matcher.add('TEST', None, pattern)
+    doc = get_doc(en_vocab, words=words2, pos=pos2)
+    matches = matcher(doc)
+    assert len(matches) == 1
+    match_id, start, end = matches[0]
+    assert match_id == en_vocab.strings['TEST']
+    assert start == 2
+    assert end == 5
+
+
+def test_phrase_matcher_string_attrs_negative(en_vocab):
+    """Test that tokens with the control codes as ORTH are *not* matched."""
+    words1 = ['I', 'like', 'cats']
+    pos1 = ['PRON', 'VERB', 'NOUN']
+    words2 = ['matcher:POS-PRON', 'matcher:POS-VERB', 'matcher:POS-NOUN']
+    pos2 = ['X', 'X', 'X']
+    pattern = get_doc(en_vocab, words=words1, pos=pos1)
+    matcher = PhraseMatcher(en_vocab, attr='POS')
+    matcher.add('TEST', None, pattern)
+    doc = get_doc(en_vocab, words=words2, pos=pos2)
+    matches = matcher(doc)
+    assert len(matches) == 0
+
+
+def test_phrase_matcher_bool_attrs(en_vocab):
+    words1 = ['Hello', 'world', '!']
+    words2 = ['No', 'problem', ',', 'he', 'said', '.']
+    pattern = Doc(en_vocab, words=words1)
+    matcher = PhraseMatcher(en_vocab, attr='IS_PUNCT')
+    matcher.add('TEST', None, pattern)
+    doc = Doc(en_vocab, words=words2)
+    matches = matcher(doc)
+    assert len(matches) == 2
+    match_id1, start1, end1 = matches[0]
+    match_id2, start2, end2 = matches[1]
+    assert match_id1 == en_vocab.strings['TEST']
+    assert match_id2 == en_vocab.strings['TEST']
+    assert start1 == 0
+    assert end1 == 3
+    assert start2 == 3
+    assert end2 == 6
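Why test_phrase_matcher_bool_attrs expects exactly two matches: under attr='IS_PUNCT' both docs reduce to boolean sequences, and the pattern's flag sequence occurs twice in the second doc:

    pattern_flags = [0, 0, 1]             # Hello world !
    doc_flags     = [0, 0, 1, 0, 0, 1]    # No problem , he said .
    # The pattern's sequence occurs at spans (0, 3) and (3, 6), matching
    # the start/end offsets asserted in the test.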