Fix phrase matcher

Matthew Honnibal 2015-10-09 02:00:45 +11:00
parent b3a70e6375
commit 801d55a6d9
1 changed file with 144 additions and 32 deletions


@@ -1,11 +1,18 @@
 # cython: profile=True
 from __future__ import unicode_literals
 from os import path
 from .typedefs cimport attr_t
+from .typedefs cimport hash_t
 from .attrs cimport attr_id_t
-from .structs cimport TokenC
+from .structs cimport TokenC, LexemeC
+from .lexeme cimport Lexeme
 from cymem.cymem cimport Pool
+from preshed.maps cimport PreshMap
 from libcpp.vector cimport vector
+from murmurhash.mrmr cimport hash64
 from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
 from .attrs cimport FLAG13, FLAG14, FLAG15, FLAG16, FLAG17, FLAG18, FLAG19, FLAG20, FLAG21, FLAG22, FLAG23, FLAG24, FLAG25
@@ -15,6 +22,38 @@ from .vocab cimport Vocab
 from libcpp.vector cimport vector
+
+from .attrs import FLAG61 as U_ENT
+
+from .attrs import FLAG60 as B2_ENT
+from .attrs import FLAG59 as B3_ENT
+from .attrs import FLAG58 as B4_ENT
+from .attrs import FLAG57 as B5_ENT
+from .attrs import FLAG56 as B6_ENT
+from .attrs import FLAG55 as B7_ENT
+from .attrs import FLAG54 as B8_ENT
+from .attrs import FLAG53 as B9_ENT
+from .attrs import FLAG52 as B10_ENT
+
+from .attrs import FLAG51 as I3_ENT
+from .attrs import FLAG50 as I4_ENT
+from .attrs import FLAG49 as I5_ENT
+from .attrs import FLAG48 as I6_ENT
+from .attrs import FLAG47 as I7_ENT
+from .attrs import FLAG46 as I8_ENT
+from .attrs import FLAG45 as I9_ENT
+from .attrs import FLAG44 as I10_ENT
+
+from .attrs import FLAG43 as L2_ENT
+from .attrs import FLAG42 as L3_ENT
+from .attrs import FLAG41 as L4_ENT
+from .attrs import FLAG40 as L5_ENT
+from .attrs import FLAG39 as L6_ENT
+from .attrs import FLAG38 as L7_ENT
+from .attrs import FLAG37 as L8_ENT
+from .attrs import FLAG36 as L9_ENT
+from .attrs import FLAG35 as L10_ENT
+
 try:
     import ujson as json
 except ImportError:
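The imports above repurpose spaCy's spare boolean lexeme flags (FLAG35 through FLAG61) as BILOU position markers: one flag per combination of role (begin/inside/last/unit) and phrase length, so B3_ENT marks a word that can begin a three-token phrase. A minimal plain-Python sketch of that bit-per-(role, length) scheme, with all names hypothetical:

    # Sketch: one bit per (role, length) pair, and a per-word bitmask recording
    # every phrase position the word has been seen in. Illustrative names only.
    FLAG_BITS = {(role, n): bit
                 for bit, (role, n) in enumerate(
                     (r, n) for n in range(1, 11) for r in 'BILU')}

    word_masks = {}  # word -> bitmask of phrase roles

    def set_flag(word, role, length):
        word_masks[word] = word_masks.get(word, 0) | (1 << FLAG_BITS[role, length])

    def has_flag(word, role, length):
        return bool(word_masks.get(word, 0) & (1 << FLAG_BITS[role, length]))

    set_flag('new', 'B', 2)   # "new" can begin a two-token phrase
    assert has_flag('new', 'B', 2) and not has_flag('new', 'L', 2)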
@@ -41,7 +80,7 @@ cdef Pattern* init_pattern(Pool mem, object token_specs, attr_t entity_type) exc
             pattern[i].spec[j].attr = attr
             pattern[i].spec[j].value = value
     i = len(token_specs)
-    pattern[i].spec = <AttrValue*>mem.alloc(1, sizeof(AttrValue))
+    pattern[i].spec = <AttrValue*>mem.alloc(2, sizeof(AttrValue))
     pattern[i].spec[0].attr = ENT_TYPE
     pattern[i].spec[0].value = entity_type
     pattern[i].spec[1].attr = LENGTH
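The one-line change above is the core of the fix: the terminal pattern state stores two attribute/value pairs, ENT_TYPE and the LENGTH sentinel, but only one AttrValue slot was allocated, so the write to spec[1] ran past the end of the buffer. A small Python stand-in for the corrected layout (the real code fills a C array from a cymem pool; names here are illustrative):

    # Two cells, matching mem.alloc(2, sizeof(AttrValue)); the old alloc(1, ...)
    # left spec[1] pointing past the allocation.
    def terminal_spec(entity_type, length_value):
        spec = [None, None]                  # room for BOTH pairs
        spec[0] = ('ENT_TYPE', entity_type)  # entity label to emit on a match
        spec[1] = ('LENGTH', length_value)   # sentinel marking the final state
        return spec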
@@ -83,6 +122,32 @@ def _convert_strings(token_specs, string_store):
     return converted
+
+def get_bilou(length):
+    if length == 1:
+        return [U_ENT]
+    elif length == 2:
+        return [B2_ENT, L2_ENT]
+    elif length == 3:
+        return [B3_ENT, I3_ENT, L3_ENT]
+    elif length == 4:
+        return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
+    elif length == 5:
+        return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
+    elif length == 6:
+        return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
+    elif length == 7:
+        return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
+    elif length == 8:
+        return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
+    elif length == 9:
+        return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
+    elif length == 10:
+        return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
+                I10_ENT, I10_ENT, L10_ENT]
+    else:
+        raise ValueError("Max length currently 10 for phrase matching")
+
 def map_attr_name(attr):
     attr = attr.upper()
     if attr == 'ORTH':
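get_bilou is a hand-unrolled lookup table: for an n-token phrase it returns the n position flags [Bn, In, ..., In, Ln], or [U] for a singleton. A generic equivalent, as a sketch that returns flag names as strings rather than the imported flag IDs:

    def get_bilou(length, max_length=10):
        # One position tag per token: U for singletons, otherwise B (begin),
        # then length - 2 copies of I (inside), then L (last).
        if not 1 <= length <= max_length:
            raise ValueError("Max length currently 10 for phrase matching")
        if length == 1:
            return ['U_ENT']
        return (['B%d_ENT' % length]
                + ['I%d_ENT' % length] * (length - 2)
                + ['L%d_ENT' % length])

    assert get_bilou(4) == ['B4_ENT', 'I4_ENT', 'I4_ENT', 'L4_ENT']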
@@ -95,32 +160,6 @@ def map_attr_name(attr):
         return SHAPE
     elif attr == 'NORM':
         return NORM
-    elif attr == 'FLAG13':
-        return FLAG13
-    elif attr == 'FLAG14':
-        return FLAG14
-    elif attr == 'FLAG15':
-        return FLAG15
-    elif attr == 'FLAG16':
-        return FLAG16
-    elif attr == 'FLAG17':
-        return FLAG17
-    elif attr == 'FLAG18':
-        return FLAG18
-    elif attr == 'FLAG19':
-        return FLAG19
-    elif attr == 'FLAG20':
-        return FLAG20
-    elif attr == 'FLAG21':
-        return FLAG21
-    elif attr == 'FLAG22':
-        return FLAG22
-    elif attr == 'FLAG23':
-        return FLAG23
-    elif attr == 'FLAG24':
-        return FLAG24
-    elif attr == 'FLAG25':
-        return FLAG25
     else:
         raise Exception("TODO: Finish supporting attr mapping %s" % attr)
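With the FLAG13 through FLAG25 branches deleted, map_attr_name is back to the core string attributes. The remaining if/elif chain could equally be a table lookup; a sketch using strings in place of the cimported attribute IDs:

    # Dict-based equivalent of map_attr_name (sketch; the real function returns
    # the cimported ORTH/LEMMA/... IDs, not strings).
    _ATTR_IDS = {'ORTH': 'ORTH', 'LEMMA': 'LEMMA', 'LOWER': 'LOWER',
                 'SHAPE': 'SHAPE', 'NORM': 'NORM'}

    def map_attr_name(attr):
        try:
            return _ATTR_IDS[attr.upper()]
        except KeyError:
            raise Exception("TODO: Finish supporting attr mapping %s" % attr)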
@@ -163,7 +202,7 @@ cdef class Matcher:
         spec = _convert_strings(spec, self.vocab.strings)
         self.patterns.push_back(init_pattern(self.mem, spec, etype))
 
-    def __call__(self, Doc doc):
+    def __call__(self, Doc doc, acceptor=None):
         cdef vector[Pattern*] partials
         cdef int n_partials = 0
         cdef int q = 0
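The new acceptor keyword argument lets callers filter candidate matches: for each completed pattern the matcher now calls acceptor(doc, label, start, end) and keeps the match only if the call returns a truthy value. PhraseMatcher relies on this below; a hypothetical standalone use might look like:

    # Assuming `matcher` is a Matcher and `doc` a tokenized Doc; this acceptor
    # rejects single-token candidates and keeps everything else.
    def longer_than_one(doc, label, start, end):
        return (end - start) > 1

    matches = matcher(doc, acceptor=longer_than_one)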
@@ -174,21 +213,94 @@ cdef class Matcher:
         for token_i in range(doc.length):
             token = &doc.data[token_i]
             q = 0
+            # Go over the open matches, extending or finalizing if able. Otherwise,
+            # we over-write them (q doesn't advance)
             for i in range(partials.size()):
                 state = partials.at(i)
                 if match(state, token):
                     if is_final(state):
-                        matches.append(get_entity(state, token, token_i))
+                        label, start, end = get_entity(state, token, token_i)
+                        if acceptor is None or acceptor(doc, label, start, end):
+                            matches.append((label, start, end))
                     else:
                         partials[q] = state + 1
                         q += 1
             partials.resize(q)
+            # Check whether we open any new patterns on this token
             for i in range(self.n_patterns):
                 state = self.patterns[i]
                 if match(state, token):
                     if is_final(state):
-                        matches.append(get_entity(state, token, token_i))
+                        label, start, end = get_entity(state, token, token_i)
+                        if acceptor is None or acceptor(doc, label, start, end):
+                            matches.append((label, start, end))
                     else:
                         partials.push_back(state + 1)
         doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + matches
         return matches
+
+
+cdef class PhraseMatcher:
+    cdef Pool mem
+    cdef Vocab vocab
+    cdef Matcher matcher
+    cdef PreshMap phrase_ids
+    cdef int max_length
+    cdef attr_t* _phrase_key
+
+    def __init__(self, Vocab vocab, phrases, max_length=10):
+        self.mem = Pool()
+        self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
+        self.max_length = max_length
+        self.vocab = vocab
+        self.matcher = Matcher(self.vocab, {})
+        self.phrase_ids = PreshMap()
+        for phrase in phrases:
+            if len(phrase) < max_length:
+                self.add(phrase)
+
+        abstract_patterns = []
+        for length in range(1, max_length):
+            abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
+        self.matcher.add('Candidate', 'MWE', {}, abstract_patterns)
+
+    def add(self, Doc tokens):
+        cdef int length = tokens.length
+        assert length < self.max_length
+        tags = get_bilou(length)
+        assert len(tags) == length, length
+        cdef int i
+        for i in range(self.max_length):
+            self._phrase_key[i] = 0
+        for i, tag in enumerate(tags):
+            lexeme = self.vocab[tokens.data[i].lex.orth]
+            lexeme.set_flag(tag, True)
+            self._phrase_key[i] = lexeme.orth
+        cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
+        self.phrase_ids[key] = True
+
+    def __call__(self, Doc doc):
+        matches = []
+        for label, start, end in self.matcher(doc, acceptor=self.accept_match):
+            cand = doc[start : end]
+            start = cand[0].idx
+            end = cand[-1].idx + len(cand[-1])
+            matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
+        for match in matches:
+            doc.merge(*match)
+        return matches
+
+    def accept_match(self, Doc doc, int label, int start, int end):
+        assert (end - start) < self.max_length
+        cdef int i, j
+        for i in range(self.max_length):
+            self._phrase_key[i] = 0
+        for i, j in enumerate(range(start, end)):
+            self._phrase_key[i] = doc.data[j].lex.orth
+        cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
+        if self.phrase_ids.get(key):
+            return True
+        else:
+            return False
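accept_match recomputes exactly the key that add stored: the _phrase_key buffer is zeroed, filled with the candidate's orth IDs, and hashed over the full max_length width, so a candidate survives only if that exact zero-padded word sequence was registered. A plain-Python sketch of the keying scheme, using a hashable tuple in place of hash64 and PreshMap:

    MAX_LENGTH = 10
    phrase_ids = set()

    def phrase_key(orth_ids):
        # Zero-padded, fixed-width key, mirroring hash64(_phrase_key, ...).
        return tuple(orth_ids) + (0,) * (MAX_LENGTH - len(orth_ids))

    def add(orth_ids):
        assert len(orth_ids) < MAX_LENGTH
        phrase_ids.add(phrase_key(orth_ids))

    def accept_match(orth_ids):
        return phrase_key(orth_ids) in phrase_ids

    add([101, 202])          # register a two-word phrase by its orth IDs
    assert accept_match([101, 202]) and not accept_match([101, 303])

End to end, the intended use appears to be PhraseMatcher(vocab, phrases) with pre-tokenized Doc phrases, then matcher(doc), which merges each accepted span into a single 'MWE' token via doc.merge.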