diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index 2cc91a368..f6f1ad3ba 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -8,6 +8,7 @@ from cymem.cymem cimport Pool from libcpp.vector cimport vector from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE +from .attrs cimport FLAG13, FLAG14, FLAG15, FLAG16, FLAG17, FLAG18, FLAG19, FLAG20, FLAG21, FLAG22, FLAG23, FLAG24, FLAG25 from .tokens.doc cimport get_token_attr from .tokens.doc cimport Doc from .vocab cimport Vocab @@ -53,6 +54,8 @@ cdef int match(const Pattern* pattern, const TokenC* token) except -1: cdef int i for i in range(pattern.length): if get_token_attr(token, pattern.spec[i].attr) != pattern.spec[i].value: + print "Pattern fail", pattern.spec[i].attr, pattern.spec[i].value + print get_token_attr(token, pattern.spec[i].attr) return False return True @@ -76,7 +79,10 @@ def _convert_strings(token_specs, string_store): attr = map_attr_name(attr) if isinstance(value, basestring): value = string_store[value] + if isinstance(value, bool): + value = int(value) converted[-1].append((attr, value)) + print "Converted", converted[-1] return converted @@ -92,6 +98,32 @@ def map_attr_name(attr): return SHAPE elif attr == 'NORM': return NORM + elif attr == 'FLAG13': + return FLAG13 + elif attr == 'FLAG14': + return FLAG14 + elif attr == 'FLAG15': + return FLAG15 + elif attr == 'FLAG16': + return FLAG16 + elif attr == 'FLAG17': + return FLAG17 + elif attr == 'FLAG18': + return FLAG18 + elif attr == 'FLAG19': + return FLAG19 + elif attr == 'FLAG20': + return FLAG20 + elif attr == 'FLAG21': + return FLAG21 + elif attr == 'FLAG22': + return FLAG22 + elif attr == 'FLAG23': + return FLAG23 + elif attr == 'FLAG24': + return FLAG24 + elif attr == 'FLAG25': + return FLAG25 else: raise Exception("TODO: Finish supporting attr mapping %s" % attr) @@ -130,6 +162,7 @@ cdef class Matcher: # TODO: Do something more clever about multiple patterns for single # entity for spec in specs: + assert len(spec) >= 1 spec = _convert_strings(spec, self.vocab.strings) self.patterns.push_back(init_pattern(self.mem, spec, etype)) @@ -142,11 +175,13 @@ cdef class Matcher: cdef Pattern* state matches = [] for token_i in range(doc.length): + print 'check', doc[token_i].orth_ token = &doc.data[token_i] q = 0 for i in range(partials.size()): state = partials.at(i) if match(state, token): + print 'match!' if is_final(state): matches.append(get_entity(state, token, token_i)) else: @@ -156,6 +191,7 @@ cdef class Matcher: for i in range(self.n_patterns): state = self.patterns[i] if match(state, token): + print 'match!' if is_final(state): matches.append(get_entity(state, token, token_i)) else: