diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx
index 601505770..66459d0c4 100644
--- a/spacy/matcher.pyx
+++ b/spacy/matcher.pyx
@@ -12,7 +12,7 @@ from .lexeme cimport attr_id_t
 from .vocab cimport Vocab
 from .tokens.doc cimport Doc
 from .tokens.doc cimport get_token_attr
-from .attrs cimport ID, attr_id_t, NULL_ATTR
+from .attrs cimport ID, attr_id_t, NULL_ATTR, ORTH

 from .errors import Errors, TempErrors, Warnings, deprecation_warning
 from .attrs import IDS
@@ -546,16 +546,21 @@ cdef class PhraseMatcher:
     cdef Matcher matcher
     cdef PreshMap phrase_ids
     cdef int max_length
+    cdef attr_id_t attr
     cdef public object _callbacks
     cdef public object _patterns

-    def __init__(self, Vocab vocab, max_length=0):
+    def __init__(self, Vocab vocab, max_length=0, attr='ORTH'):
         if max_length != 0:
             deprecation_warning(Warnings.W010)
         self.mem = Pool()
         self.max_length = max_length
         self.vocab = vocab
         self.matcher = Matcher(self.vocab)
+        if isinstance(attr, long):
+            self.attr = attr
+        else:
+            self.attr = self.vocab.strings[attr]
         self.phrase_ids = PreshMap()
         abstract_patterns = [
             [{U_ENT: True}],
@@ -609,7 +614,8 @@ cdef class PhraseMatcher:
             tags = get_bilou(length)
             phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
             for i, tag in enumerate(tags):
-                lexeme = self.vocab[doc.c[i].lex.orth]
+                attr_value = self.get_lex_value(doc, i)
+                lexeme = self.vocab[attr_value]
                 lexeme.set_flag(tag, True)
                 phrase_key[i] = lexeme.orth
             phrase_hash = hash64(phrase_key,
@@ -625,8 +631,16 @@ cdef class PhraseMatcher:
            `doc[start:end]`. The `label_id` and `key` are both integers.
         """
         matches = []
-        for _, start, end in self.matcher(doc):
-            ent_id = self.accept_match(doc, start, end)
+        if self.attr == ORTH:
+            match_doc = doc
+        else:
+            # If we're not matching on ORTH, match_doc will be a Doc whose
+            # token.orth values are the attribute values we're matching on,
+            # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
+            words = [self.get_lex_value(doc, i) for i in range(len(doc))]
+            match_doc = Doc(self.vocab, words=words)
+        for _, start, end in self.matcher(match_doc):
+            ent_id = self.accept_match(match_doc, start, end)
             if ent_id is not None:
                 matches.append((ent_id, start, end))
         for i, (ent_id, start, end) in enumerate(matches):
@@ -680,6 +694,23 @@ cdef class PhraseMatcher:
         else:
             return ent_id

+    def get_lex_value(self, Doc doc, int i):
+        if self.attr == ORTH:
+            # Return the regular orth value of the lexeme
+            return doc.c[i].lex.orth
+        # Get the attribute value instead, e.g. token.pos
+        attr_value = get_token_attr(&doc.c[i], self.attr)
+        if attr_value in (0, 1):
+            # Value is boolean, convert to string
+            string_attr_value = str(attr_value)
+        else:
+            string_attr_value = self.vocab.strings[attr_value]
+        string_attr_name = self.vocab.strings[self.attr]
+        # Concatenate the attr name and value to avoid polluting the lexeme
+        # space, e.g. 'POS-VERB' instead of just 'VERB', which could
+        # otherwise create false positive matches
+        return 'matcher:{}-{}'.format(string_attr_name, string_attr_value)
+

 cdef class DependencyTreeMatcher:
     """Match dependency parse tree based on pattern rules."""
diff --git a/spacy/tests/matcher/test_phrase_matcher.py b/spacy/tests/matcher/test_phrase_matcher.py
index 578f2b5d0..125d7be74 100644
--- a/spacy/tests/matcher/test_phrase_matcher.py
+++ b/spacy/tests/matcher/test_phrase_matcher.py
@@ -5,6 +5,8 @@ import pytest
 from spacy.matcher import PhraseMatcher
 from spacy.tokens import Doc

+from ..util import get_doc
+

 def test_matcher_phrase_matcher(en_vocab):
     doc = Doc(en_vocab, words=["Google", "Now"])
@@ -28,3 +30,53 @@ def test_phrase_matcher_contains(en_vocab):
     matcher.add('TEST', None, Doc(en_vocab, words=['test']))
     assert 'TEST' in matcher
     assert 'TEST2' not in matcher
+
+
+def test_phrase_matcher_string_attrs(en_vocab):
+    words1 = ['I', 'like', 'cats']
+    pos1 = ['PRON', 'VERB', 'NOUN']
+    words2 = ['Yes', ',', 'you', 'hate', 'dogs', 'very', 'much']
+    pos2 = ['INTJ', 'PUNCT', 'PRON', 'VERB', 'NOUN', 'ADV', 'ADV']
+    pattern = get_doc(en_vocab, words=words1, pos=pos1)
+    matcher = PhraseMatcher(en_vocab, attr='POS')
+    matcher.add('TEST', None, pattern)
+    doc = get_doc(en_vocab, words=words2, pos=pos2)
+    matches = matcher(doc)
+    assert len(matches) == 1
+    match_id, start, end = matches[0]
+    assert match_id == en_vocab.strings['TEST']
+    assert start == 2
+    assert end == 5
+
+
+def test_phrase_matcher_string_attrs_negative(en_vocab):
+    """Test that tokens with the control codes as ORTH are *not* matched."""
+    words1 = ['I', 'like', 'cats']
+    pos1 = ['PRON', 'VERB', 'NOUN']
+    words2 = ['matcher:POS-PRON', 'matcher:POS-VERB', 'matcher:POS-NOUN']
+    pos2 = ['X', 'X', 'X']
+    pattern = get_doc(en_vocab, words=words1, pos=pos1)
+    matcher = PhraseMatcher(en_vocab, attr='POS')
+    matcher.add('TEST', None, pattern)
+    doc = get_doc(en_vocab, words=words2, pos=pos2)
+    matches = matcher(doc)
+    assert len(matches) == 0
+
+
+def test_phrase_matcher_bool_attrs(en_vocab):
+    words1 = ['Hello', 'world', '!']
+    words2 = ['No', 'problem', ',', 'he', 'said', '.']
+    pattern = Doc(en_vocab, words=words1)
+    matcher = PhraseMatcher(en_vocab, attr='IS_PUNCT')
+    matcher.add('TEST', None, pattern)
+    doc = Doc(en_vocab, words=words2)
+    matches = matcher(doc)
+    assert len(matches) == 2
+    match_id1, start1, end1 = matches[0]
+    match_id2, start2, end2 = matches[1]
+    assert match_id1 == en_vocab.strings['TEST']
+    assert match_id2 == en_vocab.strings['TEST']
+    assert start1 == 0
+    assert end1 == 3
+    assert start2 == 3
+    assert end2 == 6
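
Usage note (illustrative, not part of the diff): a minimal sketch of how the new `attr` argument is meant to be used. It assumes a pipeline with a POS tagger is installed; `en_core_web_sm`, the match key, and the example sentences are placeholders, and the exact tags assigned depend on the model.

import spacy
from spacy.matcher import PhraseMatcher

nlp = spacy.load('en_core_web_sm')

# Match on coarse-grained POS tags instead of the token texts.
matcher = PhraseMatcher(nlp.vocab, attr='POS')
# The pattern contributes the tag sequence PRON VERB NOUN (model-dependent).
matcher.add('PRON_VERB_NOUN', None, nlp('I like cats'))

# Different words, same POS sequence, so this should match.
doc = nlp('You hate dogs')
for match_id, start, end in matcher(doc):
    print(nlp.vocab.strings[match_id], doc[start:end].text)

Internally, non-ORTH matching builds a proxy Doc whose words are namespaced strings such as 'matcher:POS-VERB', so a document token whose literal text is 'VERB', or even 'matcher:POS-VERB', cannot produce a false positive; that is exactly what test_phrase_matcher_string_attrs_negative guards.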
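
Boolean attributes work the same way and need no statistical model, since IS_PUNCT is a lexical attribute. Another sketch under the same caveats (the sentences are placeholders):

from spacy.lang.en import English
from spacy.matcher import PhraseMatcher

nlp = English()  # tokenizer only; IS_PUNCT comes from the vocab
matcher = PhraseMatcher(nlp.vocab, attr='IS_PUNCT')
# Pattern shape: non-punct, non-punct, punct ('Hello world !')
matcher.add('XX_PUNCT', None, nlp('Hello world !'))

doc = nlp('No problem , he said .')
# Expect two matches with the same False/False/True shape:
# 'No problem ,' and 'he said .'
for match_id, start, end in matcher(doc):
    print(doc[start:end].text)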