From 6cbdc949593301173922a827a1016e19aef6f39b Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 17 Oct 2016 15:23:31 +0200
Subject: [PATCH] Lots of updates to Matcher, to make entity handling sane.

---
 spacy/matcher.pyx | 154 ++++++++++++++++++++++++++++++++--------------
 1 file changed, 108 insertions(+), 46 deletions(-)

diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx
index 3307eb864..b6c2def8e 100644
--- a/spacy/matcher.pyx
+++ b/spacy/matcher.pyx
@@ -92,8 +92,8 @@ ctypedef TokenPatternC* TokenPatternC_ptr
 
 ctypedef pair[int, TokenPatternC_ptr] StateC
 
-cdef TokenPatternC* init_pattern(Pool mem, object token_specs, attr_t entity_id,
-                                 attr_t entity_type) except NULL:
+cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
+                                 object token_specs) except NULL:
     pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
     cdef int i
     for i, (quantifier, spec) in enumerate(token_specs):
@@ -108,7 +108,7 @@ cdef TokenPatternC* init_pattern(Pool mem, object token_specs, attr_t entity_id,
     pattern[i].attrs[0].attr = ID
     pattern[i].attrs[0].value = entity_id
     pattern[i].attrs[1].attr = ENT_TYPE
-    pattern[i].attrs[1].value = entity_type
+    pattern[i].attrs[1].value = label
     pattern[i].nr_attr = 0
     return pattern
 
@@ -163,37 +163,14 @@ def _convert_strings(token_specs, string_store):
     return tokens
 
 
-def get_bilou(length):
-    if length == 1:
-        return [U_ENT]
-    elif length == 2:
-        return [B2_ENT, L2_ENT]
-    elif length == 3:
-        return [B3_ENT, I3_ENT, L3_ENT]
-    elif length == 4:
-        return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
-    elif length == 5:
-        return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
-    elif length == 6:
-        return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
-    elif length == 7:
-        return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
-    elif length == 8:
-        return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
-    elif length == 9:
-        return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
-    elif length == 10:
-        return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
-                I10_ENT, I10_ENT, L10_ENT]
-    else:
-        raise ValueError("Max length currently 10 for phrase matching")
-
-
 cdef class Matcher:
     cdef Pool mem
     cdef vector[TokenPatternC*] patterns
     cdef readonly Vocab vocab
     cdef public object _patterns
+    cdef public object _entities
+    cdef public object _callbacks
+    cdef public object _acceptors
 
     @classmethod
     def load(cls, path, vocab):
@@ -205,12 +182,17 @@ cdef class Matcher:
         return cls(vocab, patterns)
 
     def __init__(self, vocab, patterns={}):
-        self._patterns = dict(patterns) # Make sure we own the object
+        self._patterns = {}
+        self._entities = {}
+        self._acceptors = {}
+        self._callbacks = {}
         self.vocab = vocab
         self.mem = Pool()
         self.vocab = vocab
-        for entity_key, (etype, attrs, specs) in sorted(self._patterns.items()):
-            self.add(entity_key, etype, attrs, specs)
+        for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
+            self.add_entity(entity_key, attrs)
+            for spec in specs:
+                self.add_pattern(entity_key, spec, label=etype)
 
     def __reduce__(self):
         return (self.__class__, (self.vocab, self._patterns), None, None)
@@ -218,21 +200,67 @@ cdef class Matcher:
     property n_patterns:
         def __get__(self): return self.patterns.size()
 
-    def add(self, entity_key, etype, attrs, specs):
-        self._patterns[entity_key] = (etype, dict(attrs), list(specs))
-        if isinstance(entity_key, basestring):
-            entity_key = self.vocab.strings[entity_key]
-        if isinstance(etype, basestring):
-            etype = self.vocab.strings[etype]
-        elif etype is None:
-            etype = -1
-        # TODO: Do something more clever about multiple patterns for single
-        # entity
+    def add_entity(self, entity_key, attrs=None, if_exists='raise',
+                   acceptor=None, on_match=None):
+        if if_exists not in ('raise', 'ignore', 'update'):
+            raise ValueError(
+                "Unexpected value for if_exists: %s.\n"
+                "Expected one of: ['raise', 'ignore', 'update']" % if_exists)
+        if attrs is None:
+            attrs = {}
+        entity_key = self.normalize_entity_key(entity_key)
+        if self.has_entity(entity_key):
+            if if_exists == 'raise':
+                raise KeyError(
+                    "Tried to add entity %s. Entity exists, and if_exists='raise'.\n"
+                    "Set if_exists='ignore' or if_exists='update', or check with "
+                    "matcher.has_entity()")
+            elif if_exists == 'ignore':
+                return
+        self._entities[entity_key] = dict(attrs)
+        self._patterns.setdefault(entity_key, [])
+        self._acceptors[entity_key] = acceptor
+        self._callbacks[entity_key] = on_match
+
+    def add_pattern(self, entity_key, token_specs, label=""):
+        entity_key = self.normalize_entity_key(entity_key)
+        if not self.has_entity(entity_key):
+            self.add_entity(entity_key)
+        if isinstance(label, basestring):
+            label = self.vocab.strings[label]
+
+        spec = _convert_strings(token_specs, self.vocab.strings)
+        self.patterns.push_back(init_pattern(self.mem, entity_key, label, spec))
+        self._patterns[entity_key].append((label, token_specs))
+
+    def add(self, entity_key, label, attrs, specs, acceptor=None, on_match=None):
+        self.add_entity(entity_key, attrs=attrs, if_exists='update',
+                        acceptor=acceptor, on_match=on_match)
         for spec in specs:
-            spec = _convert_strings(spec, self.vocab.strings)
-            self.patterns.push_back(init_pattern(self.mem, spec, entity_key, etype))
+            self.add_pattern(entity_key, spec, label=label)
+
+    def normalize_entity_key(self, entity_key):
+        if isinstance(entity_key, basestring):
+            return self.vocab.strings[entity_key]
+        else:
+            return entity_key
+
+    def has_entity(self, entity_key):
+        entity_key = self.normalize_entity_key(entity_key)
+        return entity_key in self._entities
+
+    def get_entity(self, entity_key):
+        entity_key = self.normalize_entity_key(entity_key)
+        if entity_key in self._entities:
+            return self._entities[entity_key]
+        else:
+            return None
 
     def __call__(self, Doc doc, acceptor=None):
+        if acceptor is not None:
+            raise ValueError(
+                "acceptor keyword argument to Matcher deprecated. Specify acceptor "
+                "functions when you add patterns instead.")
         cdef vector[StateC] partials
         cdef int n_partials = 0
         cdef int q = 0
@@ -267,8 +295,12 @@ cdef class Matcher:
                     end = token_i+1
                     ent_id = state.second[1].attrs[0].value
                     label = state.second[1].attrs[1].value
-                    if acceptor is None or acceptor(doc, ent_id, label, start, end):
-                        matches.append((ent_id, label, start, end))
+                    acceptor = self._acceptors.get(ent_id)
+                    if acceptor is not None:
+                        match = acceptor(doc, ent_id, label, start, end)
+                        if match:
+                            ent_id, label, start, end = match
+                    matches.append((ent_id, label, start, end))
             partials.resize(q)
             # Check whether we open any new patterns on this token
            for pattern in self.patterns:
@@ -293,6 +325,10 @@ cdef class Matcher:
                 label = pattern[1].attrs[1].value
                 if acceptor is None or acceptor(doc, ent_id, label, start, end):
                     matches.append((ent_id, label, start, end))
+        for i, (ent_id, label, start, end) in enumerate(matches):
+            on_match = self._callbacks.get(ent_id)
+            if on_match is not None:
+                on_match(self, doc, i, matches)
         return matches
 
     def pipe(self, docs, batch_size=1000, n_threads=2):
@@ -301,6 +337,32 @@ cdef class Matcher:
             yield doc
 
 
+def get_bilou(length):
+    if length == 1:
+        return [U_ENT]
+    elif length == 2:
+        return [B2_ENT, L2_ENT]
+    elif length == 3:
+        return [B3_ENT, I3_ENT, L3_ENT]
+    elif length == 4:
+        return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
+    elif length == 5:
+        return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
+    elif length == 6:
+        return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
+    elif length == 7:
+        return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
+    elif length == 8:
+        return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
+    elif length == 9:
+        return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
+    elif length == 10:
+        return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
+                I10_ENT, I10_ENT, L10_ENT]
+    else:
+        raise ValueError("Max length currently 10 for phrase matching")
+
+
 cdef class PhraseMatcher:
     cdef Pool mem
     cdef Vocab vocab
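
---

A minimal usage sketch of the add_entity / add_pattern / on_match / acceptor
API this patch introduces. It assumes spaCy's English model is installed and
loadable via spacy.load('en'); the entity key 'GoogleNow', the attrs dict,
the token specs, and the two helper functions are illustrative values, not
taken from the commit:

    from spacy.attrs import ORTH
    from spacy.matcher import Matcher
    import spacy

    nlp = spacy.load('en')          # assumption: English model available
    matcher = Matcher(nlp.vocab)

    def print_match(matcher, doc, i, matches):
        # on_match callbacks receive (matcher, doc, match index, all matches)
        ent_id, label, start, end = matches[i]
        print(doc[start:end])

    def min_two_tokens(doc, ent_id, label, start, end):
        # Acceptors can veto a match (return a falsy value) or rewrite it
        # by returning a new (ent_id, label, start, end) tuple.
        if end - start >= 2:
            return (ent_id, label, start, end)
        return False

    matcher.add_entity('GoogleNow',              # hypothetical entity key
                       {'ent_type': 'PRODUCT'},  # illustrative attrs dict
                       acceptor=min_two_tokens,
                       on_match=print_match)
    matcher.add_pattern('GoogleNow',
                        [{ORTH: 'Google'}, {ORTH: 'Now'}],
                        label='PRODUCT')

    doc = nlp(u'I like Google Now.')
    matches = matcher(doc)   # [(ent_id, label, start, end), ...]

The design point of the change: acceptor and on_match now live with the
entity entry, so __call__ looks them up per ent_id in self._acceptors and
self._callbacks rather than taking a single acceptor argument for the whole
call, and several add_pattern calls can share one entity's callbacks.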