spaCy/spacy/matcher.pyx

317 lines
10 KiB
Cython
Raw Normal View History

2015-10-08 15:00:45 +00:00
# cython: profile=True
from __future__ import unicode_literals
2015-08-06 12:33:21 +00:00
from os import path
from .typedefs cimport attr_t
2015-10-08 15:00:45 +00:00
from .typedefs cimport hash_t
from .attrs cimport attr_id_t
2015-10-08 15:00:45 +00:00
from .structs cimport TokenC, LexemeC
from .lexeme cimport Lexeme
from cymem.cymem cimport Pool
2015-10-08 15:00:45 +00:00
from preshed.maps cimport PreshMap
from libcpp.vector cimport vector
2015-10-08 15:00:45 +00:00
from murmurhash.mrmr cimport hash64
2015-08-06 12:33:21 +00:00
from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
from .tokens.doc cimport get_token_attr
from .tokens.doc cimport Doc
from .vocab cimport Vocab
2015-10-08 15:00:45 +00:00
from .attrs import FLAG61 as U_ENT
2016-01-13 18:46:17 +00:00
from .util import get_package
2015-10-08 15:00:45 +00:00
from .attrs import FLAG60 as B2_ENT
from .attrs import FLAG59 as B3_ENT
from .attrs import FLAG58 as B4_ENT
from .attrs import FLAG57 as B5_ENT
from .attrs import FLAG56 as B6_ENT
from .attrs import FLAG55 as B7_ENT
from .attrs import FLAG54 as B8_ENT
from .attrs import FLAG53 as B9_ENT
from .attrs import FLAG52 as B10_ENT
from .attrs import FLAG51 as I3_ENT
from .attrs import FLAG50 as I4_ENT
from .attrs import FLAG49 as I5_ENT
from .attrs import FLAG48 as I6_ENT
from .attrs import FLAG47 as I7_ENT
from .attrs import FLAG46 as I8_ENT
from .attrs import FLAG45 as I9_ENT
from .attrs import FLAG44 as I10_ENT
from .attrs import FLAG43 as L2_ENT
from .attrs import FLAG42 as L3_ENT
from .attrs import FLAG41 as L4_ENT
from .attrs import FLAG40 as L5_ENT
from .attrs import FLAG39 as L6_ENT
from .attrs import FLAG38 as L7_ENT
from .attrs import FLAG37 as L8_ENT
from .attrs import FLAG36 as L9_ENT
from .attrs import FLAG35 as L10_ENT
2015-08-06 12:33:21 +00:00
try:
import ujson as json
except ImportError:
import json
cdef struct AttrValue:
attr_id_t attr
attr_t value
cdef struct Pattern:
AttrValue* spec
int length
cdef Pattern* init_pattern(Pool mem, object token_specs, attr_t entity_type) except NULL:
pattern = <Pattern*>mem.alloc(len(token_specs) + 1, sizeof(Pattern))
cdef int i
for i, spec in enumerate(token_specs):
pattern[i].spec = <AttrValue*>mem.alloc(len(spec), sizeof(AttrValue))
pattern[i].length = len(spec)
for j, (attr, value) in enumerate(spec):
pattern[i].spec[j].attr = attr
pattern[i].spec[j].value = value
i = len(token_specs)
2015-10-08 15:00:45 +00:00
pattern[i].spec = <AttrValue*>mem.alloc(2, sizeof(AttrValue))
pattern[i].spec[0].attr = ENT_TYPE
pattern[i].spec[0].value = entity_type
pattern[i].spec[1].attr = LENGTH
pattern[i].spec[1].value = len(token_specs)
pattern[i].length = 0
return pattern
cdef int match(const Pattern* pattern, const TokenC* token) except -1:
cdef int i
for i in range(pattern.length):
if get_token_attr(token, pattern.spec[i].attr) != pattern.spec[i].value:
return False
return True
cdef int is_final(const Pattern* pattern) except -1:
return (pattern + 1).length == 0
cdef object get_entity(const Pattern* pattern, const TokenC* tokens, int i):
pattern += 1
i += 1
return (pattern.spec[0].value, i - pattern.spec[1].value, i)
2015-08-06 12:33:21 +00:00
def _convert_strings(token_specs, string_store):
converted = []
for spec in token_specs:
converted.append([])
for attr, value in spec.items():
if isinstance(attr, basestring):
attr = map_attr_name(attr)
if isinstance(value, basestring):
value = string_store[value]
if isinstance(value, bool):
value = int(value)
2015-08-06 12:33:21 +00:00
converted[-1].append((attr, value))
return converted
2015-10-08 15:00:45 +00:00
def get_bilou(length):
if length == 1:
return [U_ENT]
elif length == 2:
return [B2_ENT, L2_ENT]
elif length == 3:
return [B3_ENT, I3_ENT, L3_ENT]
elif length == 4:
return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
elif length == 5:
return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
elif length == 6:
return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
elif length == 7:
return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
elif length == 8:
return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
elif length == 9:
return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
elif length == 10:
return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
I10_ENT, I10_ENT, L10_ENT]
else:
raise ValueError("Max length currently 10 for phrase matching")
2015-08-06 12:33:21 +00:00
def map_attr_name(attr):
attr = attr.upper()
if attr == 'ORTH':
return ORTH
elif attr == 'LEMMA':
return LEMMA
elif attr == 'LOWER':
return LOWER
elif attr == 'SHAPE':
2015-08-06 12:33:21 +00:00
return SHAPE
elif attr == 'NORM':
return NORM
else:
raise Exception("TODO: Finish supporting attr mapping %s" % attr)
cdef class Matcher:
cdef Pool mem
2015-08-22 20:04:34 +00:00
cdef vector[Pattern*] patterns
cdef readonly Vocab vocab
cdef object _patterns
2015-08-22 20:04:34 +00:00
@classmethod
2016-01-16 11:23:45 +00:00
def load(cls, data_dir, Vocab vocab):
return cls.from_package(get_package(data_dir), vocab=vocab)
@classmethod
def from_package(cls, package, Vocab vocab):
2015-12-31 03:13:15 +00:00
patterns = package.load_json(('vocab', 'gazetteer.json'))
2015-12-07 05:01:28 +00:00
return cls(vocab, patterns)
def __init__(self, vocab, patterns):
self.vocab = vocab
self.mem = Pool()
self.vocab = vocab
self._patterns = dict(patterns)
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
self.add(entity_key, etype, attrs, specs)
def __reduce__(self):
return (self.__class__, (self.vocab, self._patterns), None, None)
property n_patterns:
def __get__(self): return self.patterns.size()
2015-08-22 20:04:34 +00:00
def add(self, entity_key, etype, attrs, specs):
self._patterns[entity_key] = (etype, dict(attrs), list(specs))
2015-08-22 20:04:34 +00:00
if isinstance(entity_key, basestring):
entity_key = self.vocab.strings[entity_key]
2015-08-22 20:04:34 +00:00
if isinstance(etype, basestring):
etype = self.vocab.strings[etype]
2015-08-22 20:04:34 +00:00
elif etype is None:
etype = -1
# TODO: Do something more clever about multiple patterns for single
# entity
for spec in specs:
spec = _convert_strings(spec, self.vocab.strings)
2015-08-22 20:04:34 +00:00
self.patterns.push_back(init_pattern(self.mem, spec, etype))
2015-10-08 15:00:45 +00:00
def __call__(self, Doc doc, acceptor=None):
cdef vector[Pattern*] partials
cdef int n_partials = 0
cdef int q = 0
cdef int i, token_i
cdef const TokenC* token
cdef Pattern* state
matches = []
for token_i in range(doc.length):
2015-11-03 13:15:14 +00:00
token = &doc.c[token_i]
q = 0
2015-10-08 15:00:45 +00:00
# Go over the open matches, extending or finalizing if able. Otherwise,
# we over-write them (q doesn't advance)
for state in partials:
if match(state, token):
if is_final(state):
2015-10-08 15:00:45 +00:00
label, start, end = get_entity(state, token, token_i)
if acceptor is None or acceptor(doc, label, start, end):
matches.append((label, start, end))
else:
partials[q] = state + 1
q += 1
partials.resize(q)
2015-10-08 15:00:45 +00:00
# Check whether we open any new patterns on this token
for state in self.patterns:
if match(state, token):
if is_final(state):
2015-10-08 15:00:45 +00:00
label, start, end = get_entity(state, token, token_i)
if acceptor is None or acceptor(doc, label, start, end):
matches.append((label, start, end))
else:
partials.push_back(state + 1)
seen = set()
filtered = []
for label, start, end in sorted(matches, key=lambda m: (m[1], -(m[1] - m[2]))):
if all(i in seen for i in range(start, end)):
continue
else:
for i in range(start, end):
seen.add(i)
filtered.append((label, start, end))
doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + filtered
return matches
2015-10-08 15:00:45 +00:00
cdef class PhraseMatcher:
cdef Pool mem
cdef Vocab vocab
cdef Matcher matcher
cdef PreshMap phrase_ids
cdef int max_length
cdef attr_t* _phrase_key
def __init__(self, Vocab vocab, phrases, max_length=10):
self.mem = Pool()
self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
self.max_length = max_length
self.vocab = vocab
self.matcher = Matcher(self.vocab, {})
self.phrase_ids = PreshMap()
for phrase in phrases:
if len(phrase) < max_length:
self.add(phrase)
abstract_patterns = []
for length in range(1, max_length):
abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
self.matcher.add('Candidate', 'MWE', {}, abstract_patterns)
def add(self, Doc tokens):
cdef int length = tokens.length
assert length < self.max_length
tags = get_bilou(length)
assert len(tags) == length, length
cdef int i
for i in range(self.max_length):
self._phrase_key[i] = 0
for i, tag in enumerate(tags):
2015-11-03 13:15:14 +00:00
lexeme = self.vocab[tokens.c[i].lex.orth]
2015-10-08 15:00:45 +00:00
lexeme.set_flag(tag, True)
self._phrase_key[i] = lexeme.orth
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
self.phrase_ids[key] = True
def __call__(self, Doc doc):
matches = []
for label, start, end in self.matcher(doc, acceptor=self.accept_match):
cand = doc[start : end]
start = cand[0].idx
end = cand[-1].idx + len(cand[-1])
matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
for match in matches:
doc.merge(*match)
return matches
def accept_match(self, Doc doc, int label, int start, int end):
assert (end - start) < self.max_length
cdef int i, j
for i in range(self.max_length):
self._phrase_key[i] = 0
for i, j in enumerate(range(start, end)):
2015-11-03 13:15:14 +00:00
self._phrase_key[i] = doc.c[j].lex.orth
2015-10-08 15:00:45 +00:00
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
if self.phrase_ids.get(key):
return True
else:
return False