mirror of https://github.com/explosion/spaCy.git
Lots of updates to Matcher, to make entity handling sane.
This commit is contained in:
parent
7fd98fc91c
commit
6cbdc94959
|
@ -92,8 +92,8 @@ ctypedef TokenPatternC* TokenPatternC_ptr
|
|||
ctypedef pair[int, TokenPatternC_ptr] StateC
|
||||
|
||||
|
||||
cdef TokenPatternC* init_pattern(Pool mem, object token_specs, attr_t entity_id,
|
||||
attr_t entity_type) except NULL:
|
||||
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
|
||||
object token_specs) except NULL:
|
||||
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
||||
cdef int i
|
||||
for i, (quantifier, spec) in enumerate(token_specs):
|
||||
|
@ -108,7 +108,7 @@ cdef TokenPatternC* init_pattern(Pool mem, object token_specs, attr_t entity_id,
|
|||
pattern[i].attrs[0].attr = ID
|
||||
pattern[i].attrs[0].value = entity_id
|
||||
pattern[i].attrs[1].attr = ENT_TYPE
|
||||
pattern[i].attrs[1].value = entity_type
|
||||
pattern[i].attrs[1].value = label
|
||||
pattern[i].nr_attr = 0
|
||||
return pattern
|
||||
|
||||
|
@ -163,37 +163,14 @@ def _convert_strings(token_specs, string_store):
|
|||
return tokens
|
||||
|
||||
|
||||
def get_bilou(length):
|
||||
if length == 1:
|
||||
return [U_ENT]
|
||||
elif length == 2:
|
||||
return [B2_ENT, L2_ENT]
|
||||
elif length == 3:
|
||||
return [B3_ENT, I3_ENT, L3_ENT]
|
||||
elif length == 4:
|
||||
return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
|
||||
elif length == 5:
|
||||
return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
|
||||
elif length == 6:
|
||||
return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
|
||||
elif length == 7:
|
||||
return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
|
||||
elif length == 8:
|
||||
return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
|
||||
elif length == 9:
|
||||
return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
|
||||
elif length == 10:
|
||||
return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
|
||||
I10_ENT, I10_ENT, L10_ENT]
|
||||
else:
|
||||
raise ValueError("Max length currently 10 for phrase matching")
|
||||
|
||||
|
||||
cdef class Matcher:
|
||||
cdef Pool mem
|
||||
cdef vector[TokenPatternC*] patterns
|
||||
cdef readonly Vocab vocab
|
||||
cdef public object _patterns
|
||||
cdef public object _entities
|
||||
cdef public object _callbacks
|
||||
cdef public object _acceptors
|
||||
|
||||
@classmethod
|
||||
def load(cls, path, vocab):
|
||||
|
@ -205,12 +182,17 @@ cdef class Matcher:
|
|||
return cls(vocab, patterns)
|
||||
|
||||
def __init__(self, vocab, patterns={}):
|
||||
self._patterns = dict(patterns) # Make sure we own the object
|
||||
self._patterns = {}
|
||||
self._entities = {}
|
||||
self._acceptors = {}
|
||||
self._callbacks = {}
|
||||
self.vocab = vocab
|
||||
self.mem = Pool()
|
||||
self.vocab = vocab
|
||||
for entity_key, (etype, attrs, specs) in sorted(self._patterns.items()):
|
||||
self.add(entity_key, etype, attrs, specs)
|
||||
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
|
||||
self.add_entity(entity_key, attrs)
|
||||
for spec in specs:
|
||||
self.add_pattern(entity_key, spec, label=etype)
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self.vocab, self._patterns), None, None)
|
||||
|
@ -218,21 +200,67 @@ cdef class Matcher:
|
|||
property n_patterns:
|
||||
def __get__(self): return self.patterns.size()
|
||||
|
||||
def add(self, entity_key, etype, attrs, specs):
|
||||
self._patterns[entity_key] = (etype, dict(attrs), list(specs))
|
||||
if isinstance(entity_key, basestring):
|
||||
entity_key = self.vocab.strings[entity_key]
|
||||
if isinstance(etype, basestring):
|
||||
etype = self.vocab.strings[etype]
|
||||
elif etype is None:
|
||||
etype = -1
|
||||
# TODO: Do something more clever about multiple patterns for single
|
||||
# entity
|
||||
def add_entity(self, entity_key, attrs=None, if_exists='raise',
|
||||
acceptor=None, on_match=None):
|
||||
if if_exists not in ('raise', 'ignore', 'update'):
|
||||
raise ValueError(
|
||||
"Unexpected value for if_exists: %s.\n"
|
||||
"Expected one of: ['raise', 'ignore', 'update']" % if_exists)
|
||||
if attrs is None:
|
||||
attrs = {}
|
||||
entity_key = self.normalize_entity_key(entity_key)
|
||||
if self.has_entity(entity_key):
|
||||
if if_exists == 'raise':
|
||||
raise KeyError(
|
||||
"Tried to add entity %s. Entity exists, and if_exists='raise'.\n"
|
||||
"Set if_exists='ignore' or if_exists='update', or check with "
|
||||
"matcher.has_entity()")
|
||||
elif if_exists == 'ignore':
|
||||
return
|
||||
self._entities[entity_key] = dict(attrs)
|
||||
self._patterns.setdefault(entity_key, [])
|
||||
self._acceptors[entity_key] = acceptor
|
||||
self._callbacks[entity_key] = on_match
|
||||
|
||||
def add_pattern(self, entity_key, token_specs, label=""):
|
||||
entity_key = self.normalize_entity_key(entity_key)
|
||||
if not self.has_entity(entity_key):
|
||||
self.add_entity(entity_key)
|
||||
if isinstance(label, basestring):
|
||||
label = self.vocab.strings[label]
|
||||
|
||||
spec = _convert_strings(token_specs, self.vocab.strings)
|
||||
self.patterns.push_back(init_pattern(self.mem, entity_key, label, spec))
|
||||
self._patterns[entity_key].append((label, token_specs))
|
||||
|
||||
def add(self, entity_key, label, attrs, specs, acceptor=None, on_match=None):
|
||||
self.add_entity(entity_key, attrs=attrs, if_exists='update',
|
||||
acceptor=acceptor, on_match=on_match)
|
||||
for spec in specs:
|
||||
spec = _convert_strings(spec, self.vocab.strings)
|
||||
self.patterns.push_back(init_pattern(self.mem, spec, entity_key, etype))
|
||||
self.add_pattern(entity_key, spec, label=label)
|
||||
|
||||
def normalize_entity_key(self, entity_key):
|
||||
if isinstance(entity_key, basestring):
|
||||
return self.vocab.strings[entity_key]
|
||||
else:
|
||||
return entity_key
|
||||
|
||||
def has_entity(self, entity_key):
|
||||
entity_key = self.normalize_entity_key(entity_key)
|
||||
return entity_key in self._entities
|
||||
|
||||
def get_entity(self, entity_key):
|
||||
entity_key = self.normalize_entity_key(entity_key)
|
||||
if entity_key in self._entities:
|
||||
return self._entities[entity_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
def __call__(self, Doc doc, acceptor=None):
|
||||
if acceptor is not None:
|
||||
raise ValueError(
|
||||
"acceptor keyword argument to Matcher deprecated. Specify acceptor "
|
||||
"functions when you add patterns instead.")
|
||||
cdef vector[StateC] partials
|
||||
cdef int n_partials = 0
|
||||
cdef int q = 0
|
||||
|
@ -267,7 +295,11 @@ cdef class Matcher:
|
|||
end = token_i+1
|
||||
ent_id = state.second[1].attrs[0].value
|
||||
label = state.second[1].attrs[1].value
|
||||
if acceptor is None or acceptor(doc, ent_id, label, start, end):
|
||||
acceptor = self._acceptors.get(ent_id)
|
||||
if acceptor is not None:
|
||||
match = acceptor(doc, ent_id, label, start, end)
|
||||
if match:
|
||||
ent_id, label, start, end = match
|
||||
matches.append((ent_id, label, start, end))
|
||||
partials.resize(q)
|
||||
# Check whether we open any new patterns on this token
|
||||
|
@ -293,6 +325,10 @@ cdef class Matcher:
|
|||
label = pattern[1].attrs[1].value
|
||||
if acceptor is None or acceptor(doc, ent_id, label, start, end):
|
||||
matches.append((ent_id, label, start, end))
|
||||
for i, (ent_id, label, start, end) in enumerate(matches):
|
||||
on_match = self._callbacks.get(ent_id)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matches)
|
||||
return matches
|
||||
|
||||
def pipe(self, docs, batch_size=1000, n_threads=2):
|
||||
|
@ -301,6 +337,32 @@ cdef class Matcher:
|
|||
yield doc
|
||||
|
||||
|
||||
def get_bilou(length):
|
||||
if length == 1:
|
||||
return [U_ENT]
|
||||
elif length == 2:
|
||||
return [B2_ENT, L2_ENT]
|
||||
elif length == 3:
|
||||
return [B3_ENT, I3_ENT, L3_ENT]
|
||||
elif length == 4:
|
||||
return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
|
||||
elif length == 5:
|
||||
return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
|
||||
elif length == 6:
|
||||
return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
|
||||
elif length == 7:
|
||||
return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
|
||||
elif length == 8:
|
||||
return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
|
||||
elif length == 9:
|
||||
return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
|
||||
elif length == 10:
|
||||
return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
|
||||
I10_ENT, I10_ENT, L10_ENT]
|
||||
else:
|
||||
raise ValueError("Max length currently 10 for phrase matching")
|
||||
|
||||
|
||||
cdef class PhraseMatcher:
|
||||
cdef Pool mem
|
||||
cdef Vocab vocab
|
||||
|
|
Loading…
Reference in New Issue