diff --git a/examples/multi_word_matches.py b/examples/multi_word_matches.py index 59d3c2a63..3c715736e 100644 --- a/examples/multi_word_matches.py +++ b/examples/multi_word_matches.py @@ -26,137 +26,71 @@ from ast import literal_eval from bz2 import BZ2File import time import math +import codecs import plac from preshed.maps import PreshMap +from preshed.counter import PreshCounter from spacy.strings import hash_string from spacy.en import English -from spacy.matcher import Matcher - -from spacy.attrs import FLAG63 as B_ENT -from spacy.attrs import FLAG62 as L_ENT -from spacy.attrs import FLAG61 as I_ENT - -from spacy.attrs import FLAG60 as B2_ENT -from spacy.attrs import FLAG59 as B3_ENT -from spacy.attrs import FLAG58 as B4_ENT -from spacy.attrs import FLAG57 as B5_ENT -from spacy.attrs import FLAG56 as B6_ENT -from spacy.attrs import FLAG55 as B7_ENT -from spacy.attrs import FLAG54 as B8_ENT -from spacy.attrs import FLAG53 as B9_ENT -from spacy.attrs import FLAG52 as B10_ENT - -from spacy.attrs import FLAG51 as I3_ENT -from spacy.attrs import FLAG50 as I4_ENT -from spacy.attrs import FLAG49 as I5_ENT -from spacy.attrs import FLAG48 as I6_ENT -from spacy.attrs import FLAG47 as I7_ENT -from spacy.attrs import FLAG46 as I8_ENT -from spacy.attrs import FLAG45 as I9_ENT -from spacy.attrs import FLAG44 as I10_ENT - -from spacy.attrs import FLAG43 as L2_ENT -from spacy.attrs import FLAG42 as L3_ENT -from spacy.attrs import FLAG41 as L4_ENT -from spacy.attrs import FLAG40 as L5_ENT -from spacy.attrs import FLAG39 as L6_ENT -from spacy.attrs import FLAG38 as L7_ENT -from spacy.attrs import FLAG37 as L8_ENT -from spacy.attrs import FLAG36 as L9_ENT -from spacy.attrs import FLAG35 as L10_ENT +from spacy.matcher import PhraseMatcher -def get_bilou(length): - if length == 1: - return [U_ENT] - elif length == 2: - return [B2_ENT, L2_ENT] - elif length == 3: - return [B3_ENT, I3_ENT, L3_ENT] - elif length == 4: - return [B4_ENT, I4_ENT, I4_ENT, L4_ENT] - elif length == 5: - return [B5_ENT, I5_ENT, I5_ENT, L5_ENT] - elif length == 6: - return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT] - elif length == 7: - return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT] - elif length == 8: - return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT] - elif length == 9: - return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT] - elif length == 10: - return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, L10_ENT] - - -def make_matcher(vocab, max_length): - abstract_patterns = [] - for length in range(2, max_length): - abstract_patterns.append([{tag: True} for tag in get_bilou(length)]) - return Matcher(vocab, {'Candidate': ('CAND', {}, abstract_patterns)}) - - -def get_matches(matcher, pattern_ids, doc): - matches = [] - for label, start, end in matcher(doc): - candidate = doc[start : end] - if pattern_ids[hash_string(candidate.text)] == True: - start = candidate[0].idx - end = candidate[-1].idx + len(candidate[-1]) - matches.append((start, end, candidate.root.tag_, candidate.text)) - return matches - - -def merge_matches(doc, matches): - for start, end, tag, text in matches: - doc.merge(start, end, tag, text, 'MWE') - - -def read_gazetteer(loc): - for line in open(loc): +def read_gazetteer(tokenizer, loc, n=-1): + for i, line in enumerate(open(loc)): phrase = literal_eval('u' + line.strip()) if ' (' in phrase and phrase.endswith(')'): phrase = phrase.split(' (', 1)[0] - yield phrase + if i >= n: + break + phrase = tokenizer(phrase) + if len(phrase) >= 2: + yield phrase + def read_text(bz2_loc): with BZ2File(bz2_loc) as file_: for line in file_: yield line.decode('utf8') -def main(patterns_loc, text_loc): + +def get_matches(tokenizer, phrases, texts, max_length=6): + matcher = PhraseMatcher(tokenizer.vocab, phrases, max_length=max_length) + print("Match") + for text in texts: + doc = tokenizer(text) + matches = matcher(doc) + for mwe in doc.ents: + yield mwe + + +def main(patterns_loc, text_loc, counts_loc, n=10000000): nlp = English(parser=False, tagger=False, entity=False) - - pattern_ids = PreshMap() - max_length = 10 - i = 0 - for pattern_str in read_gazetteer(patterns_loc): - pattern = nlp.tokenizer(pattern_str) - if len(pattern) < 2 or len(pattern) >= max_length: - continue - bilou_tags = get_bilou(len(pattern)) - for word, tag in zip(pattern, bilou_tags): - lexeme = nlp.vocab[word.orth] - lexeme.set_flag(tag, True) - pattern_ids[hash_string(pattern.text)] = True - i += 1 - if i >= 10000001: - break - - matcher = make_matcher(nlp.vocab, max_length) - + print("Make matcher") + phrases = read_gazetteer(nlp.tokenizer, patterns_loc, n=n) + counts = PreshCounter() t1 = time.time() - - for text in read_text(text_loc): - doc = nlp.tokenizer(text) - matches = get_matches(matcher, pattern_ids, doc) - merge_matches(doc, matches) + for mwe in get_matches(nlp.tokenizer, phrases, read_text(text_loc)): + counts.inc(hash_string(mwe.text), 1) t2 = time.time() - print('10 ^ %d patterns took %d s' % (round(math.log(i, 10)), t2-t1)) - + print("10m tokens in %d s" % (t2 - t1)) + + with codecs.open(counts_loc, 'w', 'utf8') as file_: + for phrase in read_gazetteer(nlp.tokenizer, patterns_loc, n=n): + text = phrase.string + key = hash_string(text) + count = counts[key] + if count != 0: + file_.write('%d\t%s\n' % (count, text)) if __name__ == '__main__': - plac.call(main) + if False: + import cProfile + import pstats + cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof") + s = pstats.Stats("Profile.prof") + s.strip_dirs().sort_stats("time").print_stats() + else: + plac.call(main)