* Temporarily import flag attributes in matcher

This commit is contained in:
Matthew Honnibal 2015-09-06 17:53:12 +02:00
parent 7cc56ada6e
commit 6427a3fcac
1 changed files with 36 additions and 0 deletions

View File

@ -8,6 +8,7 @@ from cymem.cymem cimport Pool
from libcpp.vector cimport vector
from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
from .attrs cimport FLAG13, FLAG14, FLAG15, FLAG16, FLAG17, FLAG18, FLAG19, FLAG20, FLAG21, FLAG22, FLAG23, FLAG24, FLAG25
from .tokens.doc cimport get_token_attr
from .tokens.doc cimport Doc
from .vocab cimport Vocab
@ -53,6 +54,8 @@ cdef int match(const Pattern* pattern, const TokenC* token) except -1:
cdef int i
for i in range(pattern.length):
if get_token_attr(token, pattern.spec[i].attr) != pattern.spec[i].value:
print "Pattern fail", pattern.spec[i].attr, pattern.spec[i].value
print get_token_attr(token, pattern.spec[i].attr)
return False
return True
@ -76,7 +79,10 @@ def _convert_strings(token_specs, string_store):
attr = map_attr_name(attr)
if isinstance(value, basestring):
value = string_store[value]
if isinstance(value, bool):
value = int(value)
converted[-1].append((attr, value))
print "Converted", converted[-1]
return converted
@ -92,6 +98,32 @@ def map_attr_name(attr):
return SHAPE
elif attr == 'NORM':
return NORM
elif attr == 'FLAG13':
return FLAG13
elif attr == 'FLAG14':
return FLAG14
elif attr == 'FLAG15':
return FLAG15
elif attr == 'FLAG16':
return FLAG16
elif attr == 'FLAG17':
return FLAG17
elif attr == 'FLAG18':
return FLAG18
elif attr == 'FLAG19':
return FLAG19
elif attr == 'FLAG20':
return FLAG20
elif attr == 'FLAG21':
return FLAG21
elif attr == 'FLAG22':
return FLAG22
elif attr == 'FLAG23':
return FLAG23
elif attr == 'FLAG24':
return FLAG24
elif attr == 'FLAG25':
return FLAG25
else:
raise Exception("TODO: Finish supporting attr mapping %s" % attr)
@ -130,6 +162,7 @@ cdef class Matcher:
# TODO: Do something more clever about multiple patterns for single
# entity
for spec in specs:
assert len(spec) >= 1
spec = _convert_strings(spec, self.vocab.strings)
self.patterns.push_back(init_pattern(self.mem, spec, etype))
@ -142,11 +175,13 @@ cdef class Matcher:
cdef Pattern* state
matches = []
for token_i in range(doc.length):
print 'check', doc[token_i].orth_
token = &doc.data[token_i]
q = 0
for i in range(partials.size()):
state = partials.at(i)
if match(state, token):
print 'match!'
if is_final(state):
matches.append(get_entity(state, token, token_i))
else:
@ -156,6 +191,7 @@ cdef class Matcher:
for i in range(self.n_patterns):
state = self.patterns[i]
if match(state, token):
print 'match!'
if is_final(state):
matches.append(get_entity(state, token, token_i))
else: