# cython: infer_types=True
# cython: profile=True
from __future__ import unicode_literals

from cymem.cymem cimport Pool
from murmurhash.mrmr cimport hash64
from preshed.maps cimport PreshMap

from .matcher cimport Matcher
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc, get_token_attr
from ..typedefs cimport attr_t, hash_t

from ..errors import Warnings, deprecation_warning, user_warning
# Lexeme flags used to mark the position of a token inside a phrase of a
# given length (U = single-token phrase, B = first, I = inside, L = last;
# the digit is the phrase-length bucket, with 4 meaning "4 or more").
# NOTE(review): I3_ENT/L3_ENT both alias FLAG42 and I4_ENT/L4_ENT both alias
# FLAG41, so "inside" and "last" tokens are not distinguished at the flag
# level -- confirm this is intended.
from ..attrs import FLAG61 as U_ENT
from ..attrs import FLAG60 as B2_ENT
from ..attrs import FLAG59 as B3_ENT
from ..attrs import FLAG58 as B4_ENT
from ..attrs import FLAG43 as L2_ENT
from ..attrs import FLAG42 as L3_ENT
from ..attrs import FLAG41 as L4_ENT
from ..attrs import FLAG42 as I3_ENT
from ..attrs import FLAG41 as I4_ENT


cdef class PhraseMatcher:
    """Match sequences of tokens against a set of phrase patterns.

    Patterns are given as `Doc` objects. Each pattern token's lexeme is
    flagged with a positional tag (see `get_bilou`), a generic `Matcher`
    proposes candidate spans based on those flags, and candidates are
    confirmed by hashing the span's orth IDs and looking the hash up in
    `phrase_ids`.
    """
    cdef Pool mem
    cdef Vocab vocab
    cdef Matcher matcher
    cdef PreshMap phrase_ids  # hash of a phrase's orth IDs -> match ID
    cdef int max_length
    cdef attr_id_t attr
    cdef public object _callbacks  # match ID -> on_match callback (or None)
    cdef public object _patterns
    cdef public object _validate

    def __init__(self, Vocab vocab, max_length=0, attr='ORTH', validate=False):
        """Initialize the PhraseMatcher.

        vocab (Vocab): The vocabulary shared with the documents to match.
        max_length (int): Deprecated. Any non-zero value triggers the
            deprecation warning W010; the value is only stored.
        attr (int / unicode): The token attribute to match on, as an ID or
            as a name that is resolved through `vocab.strings`.
        validate (bool): If True, `add` emits warning W012 for patterns that
            are tagged or parsed while `attr` is not DEP, POS, TAG or LEMMA.
        RETURNS (PhraseMatcher): The newly constructed object.
        """
        if max_length != 0:
            deprecation_warning(Warnings.W010)
        self.mem = Pool()
        self.max_length = max_length
        self.vocab = vocab
        self.matcher = Matcher(self.vocab, validate=False)
        if isinstance(attr, long):
            self.attr = attr
        else:
            self.attr = self.vocab.strings[attr]
        self.phrase_ids = PreshMap()
        # One abstract pattern per phrase-length bucket: 1, 2, 3 and 4+ tokens.
        abstract_patterns = [
            [{U_ENT: True}],
            [{B2_ENT: True}, {L2_ENT: True}],
            [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}],
            [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}],
        ]
        self.matcher.add('Candidate', None, *abstract_patterns)
        self._callbacks = {}
        self._validate = validate

    def __len__(self):
        """Get the number of rules added to the matcher. Note that this only
        returns the number of rules (identical with the number of IDs), not the
        number of individual patterns.

        RETURNS (int): The number of rules.
        """
        # `phrase_ids` holds one entry per individual pattern, whereas
        # `_callbacks` holds one entry per match ID, which is what the
        # documented contract (and `__contains__`) is based on.
        return len(self._callbacks)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        return ent_id in self._callbacks

    def __reduce__(self):
        """Support pickling by re-creating the matcher from its vocab.

        Only the vocab is passed to the constructor, so `attr`, `validate`
        and any added patterns are not carried over.
        """
        return (self.__class__, (self.vocab,), None, None)

    def add(self, key, on_match, *docs):
        """Add a match-rule to the phrase-matcher. A match-rule consists of: an
        ID key, an on_match callback, and one or more patterns.

        key (unicode): The match ID.
        on_match (callable): Callback executed on match.
        *docs (Doc): `Doc` objects representing match patterns. Empty docs
            are skipped.
        """
        cdef Doc doc
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        self._callbacks[ent_id] = on_match
        cdef int length
        cdef int i
        cdef hash_t phrase_hash
        cdef attr_t* phrase_key
        # Scratch pool: the key buffers are only needed to compute the hashes.
        cdef Pool mem = Pool()
        for doc in docs:
            length = doc.length
            if length == 0:
                continue
            if (self._validate and (doc.is_tagged or doc.is_parsed)
                    and self.attr not in (DEP, POS, TAG, LEMMA)):
                string_attr = self.vocab.strings[self.attr]
                user_warning(Warnings.W012.format(key=key, attr=string_attr))
            tags = get_bilou(length)
            # Pool.alloc returns an untyped pointer; cast it so that it can
            # be indexed as an array of attribute values.
            phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
            for i, tag in enumerate(tags):
                attr_value = self.get_lex_value(doc, i)
                lexeme = self.vocab[attr_value]
                lexeme.set_flag(tag, True)
                phrase_key[i] = lexeme.orth
            phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
            # The map stores pointer-sized values, so the match ID is stored
            # in place of a pointer.
            self.phrase_ids.set(phrase_hash, <void*>ent_id)

    def __call__(self, Doc doc):
        """Find all sequences matching the supplied patterns on the `Doc`.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, start, end)` tuples,
            describing the matches. A match tuple describes a span
            `doc[start:end]`. The `label_id` and `key` are both integers.
        """
        matches = []
        if self.attr == ORTH:
            match_doc = doc
        else:
            # If we're not matching on the ORTH, match_doc will be a Doc whose
            # token.orth values are the attribute values we're matching on,
            # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
            words = [self.get_lex_value(doc, i) for i in range(len(doc))]
            match_doc = Doc(self.vocab, words=words)
        # The inner matcher only proposes candidates from the positional
        # flags; each candidate is confirmed against the phrase hashes.
        for _, start, end in self.matcher(match_doc):
            ent_id = self.accept_match(match_doc, start, end)
            if ent_id is not None:
                matches.append((ent_id, start, end))
        # Callbacks receive the original doc, not the derived match_doc.
        for i, (ent_id, start, end) in enumerate(matches):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

    def pipe(self, stream, batch_size=1000, n_threads=1, return_matches=False,
             as_tuples=False):
        """Match a stream of documents, yielding them in turn.

        stream (iterable): A stream of documents, or of (doc, context)
            tuples if `as_tuples` is True.
        batch_size (int): Number of documents to accumulate into a working set.
            Not used by this implementation.
        n_threads (int): The number of threads with which to work on the buffer
            in parallel, if the implementation supports multi-threading.
            Not used by this implementation.
        return_matches (bool): Yield the match lists along with the docs, making
            results (doc, matches) tuples.
        as_tuples (bool): Interpret the input stream as (doc, context) tuples,
            and yield (result, context) tuples out.
            If both return_matches and as_tuples are True, the output will
            be a sequence of ((doc, matches), context) tuples.
        YIELDS (Doc): Documents, in order.
        """
        if as_tuples:
            for doc, context in stream:
                matches = self(doc)
                if return_matches:
                    yield ((doc, matches), context)
                else:
                    yield (doc, context)
        else:
            for doc in stream:
                matches = self(doc)
                if return_matches:
                    yield (doc, matches)
                else:
                    yield doc

    def accept_match(self, Doc doc, int start, int end):
        """Confirm a candidate span by looking up the hash of its orth IDs.

        doc (Doc): The document the candidate span was found in.
        start (int): Index of the first token of the span.
        end (int): Index one past the last token of the span.
        RETURNS (int / None): The match ID stored for the span's hash, or
            None if the lookup yields 0 (treated as "not found").
        """
        cdef int i, j
        cdef attr_t* phrase_key
        cdef hash_t key
        cdef hash_t ent_id
        # Scratch pool: the key buffer is only needed to compute the hash.
        cdef Pool mem = Pool()
        phrase_key = <attr_t*>mem.alloc(end-start, sizeof(attr_t))
        for i, j in enumerate(range(start, end)):
            phrase_key[i] = doc.c[j].lex.orth
        key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
        # The stored value is the match ID that `add` put in place of a
        # pointer, so cast it back to an integer.
        ent_id = <hash_t>self.phrase_ids.get(key)
        if ent_id == 0:
            return None
        else:
            return ent_id

    def get_lex_value(self, Doc doc, int i):
        """Get the value to match on for the token at position `i`.

        doc (Doc): The document containing the token.
        i (int): The index of the token.
        RETURNS (int / unicode): The lexeme's orth ID if matching on ORTH,
            otherwise a string of the form 'matcher:{attr name}-{attr value}'.
        """
        if self.attr == ORTH:
            # Return the regular orth value of the lexeme
            return doc.c[i].lex.orth
        # Get the attribute value instead, e.g. token.pos
        attr_value = get_token_attr(&doc.c[i], self.attr)
        if attr_value in (0, 1):
            # Value is boolean, convert to string
            string_attr_value = str(attr_value)
        else:
            string_attr_value = self.vocab.strings[attr_value]
        string_attr_name = self.vocab.strings[self.attr]
        # Concatenate the attr name and value to not pollute lexeme space
        # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
        # create false positive matches
        return 'matcher:{}-{}'.format(string_attr_name, string_attr_value)


def get_bilou(length):
    """Get the positional flags for the tokens of a phrase.

    length (int): The number of tokens in the phrase.
    RETURNS (list): One flag ID per token. Phrases of 1, 2 and 3 tokens use
        their own flag sets; phrases of 4 or more tokens share the *4_ENT
        flags.
    RAISES (ValueError): If `length` is 0.
    """
    if length == 0:
        raise ValueError("Length must be >= 1")
    elif length == 1:
        return [U_ENT]
    elif length == 2:
        return [B2_ENT, L2_ENT]
    elif length == 3:
        return [B3_ENT, I3_ENT, L3_ENT]
    else:
        # B + (length - 2) inside tokens + L == length flags in total.
        return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]