* Fix multi word matcher

Matthew Honnibal 2015-10-09 02:02:37 +11:00
parent 801d55a6d9
commit 4bbc8f45c6
1 changed file with 45 additions and 111 deletions


@@ -26,137 +26,71 @@ from ast import literal_eval
 from bz2 import BZ2File
 import time
 import math
+import codecs
 
 import plac
 from preshed.maps import PreshMap
+from preshed.counter import PreshCounter
 from spacy.strings import hash_string
 from spacy.en import English
-from spacy.matcher import Matcher
-
-from spacy.attrs import FLAG63 as B_ENT
-from spacy.attrs import FLAG62 as L_ENT
-from spacy.attrs import FLAG61 as I_ENT
-from spacy.attrs import FLAG60 as B2_ENT
-from spacy.attrs import FLAG59 as B3_ENT
-from spacy.attrs import FLAG58 as B4_ENT
-from spacy.attrs import FLAG57 as B5_ENT
-from spacy.attrs import FLAG56 as B6_ENT
-from spacy.attrs import FLAG55 as B7_ENT
-from spacy.attrs import FLAG54 as B8_ENT
-from spacy.attrs import FLAG53 as B9_ENT
-from spacy.attrs import FLAG52 as B10_ENT
-from spacy.attrs import FLAG51 as I3_ENT
-from spacy.attrs import FLAG50 as I4_ENT
-from spacy.attrs import FLAG49 as I5_ENT
-from spacy.attrs import FLAG48 as I6_ENT
-from spacy.attrs import FLAG47 as I7_ENT
-from spacy.attrs import FLAG46 as I8_ENT
-from spacy.attrs import FLAG45 as I9_ENT
-from spacy.attrs import FLAG44 as I10_ENT
-from spacy.attrs import FLAG43 as L2_ENT
-from spacy.attrs import FLAG42 as L3_ENT
-from spacy.attrs import FLAG41 as L4_ENT
-from spacy.attrs import FLAG40 as L5_ENT
-from spacy.attrs import FLAG39 as L6_ENT
-from spacy.attrs import FLAG38 as L7_ENT
-from spacy.attrs import FLAG37 as L8_ENT
-from spacy.attrs import FLAG36 as L9_ENT
-from spacy.attrs import FLAG35 as L10_ENT
+from spacy.matcher import PhraseMatcher
 
 
-def get_bilou(length):
-    if length == 1:
-        return [U_ENT]
-    elif length == 2:
-        return [B2_ENT, L2_ENT]
-    elif length == 3:
-        return [B3_ENT, I3_ENT, L3_ENT]
-    elif length == 4:
-        return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
-    elif length == 5:
-        return [B5_ENT, I5_ENT, I5_ENT, L5_ENT]
-    elif length == 6:
-        return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
-    elif length == 7:
-        return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
-    elif length == 8:
-        return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
-    elif length == 9:
-        return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
-    elif length == 10:
-        return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, L10_ENT]
-
-
-def make_matcher(vocab, max_length):
-    abstract_patterns = []
-    for length in range(2, max_length):
-        abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
-    return Matcher(vocab, {'Candidate': ('CAND', {}, abstract_patterns)})
-
-
-def get_matches(matcher, pattern_ids, doc):
-    matches = []
-    for label, start, end in matcher(doc):
-        candidate = doc[start : end]
-        if pattern_ids[hash_string(candidate.text)] == True:
-            start = candidate[0].idx
-            end = candidate[-1].idx + len(candidate[-1])
-            matches.append((start, end, candidate.root.tag_, candidate.text))
-    return matches
-
-
-def merge_matches(doc, matches):
-    for start, end, tag, text in matches:
-        doc.merge(start, end, tag, text, 'MWE')
-
-
-def read_gazetteer(loc):
-    for line in open(loc):
+def read_gazetteer(tokenizer, loc, n=-1):
+    for i, line in enumerate(open(loc)):
         phrase = literal_eval('u' + line.strip())
         if ' (' in phrase and phrase.endswith(')'):
             phrase = phrase.split(' (', 1)[0]
-        yield phrase
+        if i >= n:
+            break
+        phrase = tokenizer(phrase)
+        if len(phrase) >= 2:
+            yield phrase
 
 
 def read_text(bz2_loc):
     with BZ2File(bz2_loc) as file_:
         for line in file_:
             yield line.decode('utf8')
 
 
-def main(patterns_loc, text_loc):
+def get_matches(tokenizer, phrases, texts, max_length=6):
+    matcher = PhraseMatcher(tokenizer.vocab, phrases, max_length=max_length)
+    print("Match")
+    for text in texts:
+        doc = tokenizer(text)
+        matches = matcher(doc)
+        for mwe in doc.ents:
+            yield mwe
+
+
+def main(patterns_loc, text_loc, counts_loc, n=10000000):
     nlp = English(parser=False, tagger=False, entity=False)
-    pattern_ids = PreshMap()
-    max_length = 10
-    i = 0
-    for pattern_str in read_gazetteer(patterns_loc):
-        pattern = nlp.tokenizer(pattern_str)
-        if len(pattern) < 2 or len(pattern) >= max_length:
-            continue
-        bilou_tags = get_bilou(len(pattern))
-        for word, tag in zip(pattern, bilou_tags):
-            lexeme = nlp.vocab[word.orth]
-            lexeme.set_flag(tag, True)
-        pattern_ids[hash_string(pattern.text)] = True
-        i += 1
-        if i >= 10000001:
-            break
-    matcher = make_matcher(nlp.vocab, max_length)
+    print("Make matcher")
+    phrases = read_gazetteer(nlp.tokenizer, patterns_loc, n=n)
+    counts = PreshCounter()
     t1 = time.time()
-    for text in read_text(text_loc):
-        doc = nlp.tokenizer(text)
-        matches = get_matches(matcher, pattern_ids, doc)
-        merge_matches(doc, matches)
+    for mwe in get_matches(nlp.tokenizer, phrases, read_text(text_loc)):
+        counts.inc(hash_string(mwe.text), 1)
     t2 = time.time()
-    print('10 ^ %d patterns took %d s' % (round(math.log(i, 10)), t2-t1))
+    print("10m tokens in %d s" % (t2 - t1))
+    with codecs.open(counts_loc, 'w', 'utf8') as file_:
+        for phrase in read_gazetteer(nlp.tokenizer, patterns_loc, n=n):
+            text = phrase.string
+            key = hash_string(text)
+            count = counts[key]
+            if count != 0:
+                file_.write('%d\t%s\n' % (count, text))
 
 
 if __name__ == '__main__':
-    plac.call(main)
+    if False:
+        import cProfile
+        import pstats
+        cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof")
+        s = pstats.Stats("Profile.prof")
+        s.strip_dirs().sort_stats("time").print_stats()
+    else:
+        plac.call(main)
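
The net effect of the change is easier to see outside the diff. Below is a minimal sketch of the new flow, assuming only the 2015-era API the diff itself uses (spacy.en.English, PhraseMatcher(vocab, phrases, max_length=...), PreshCounter, and matches surfacing on doc.ents, as in get_matches above); the gazetteer entry and input text are hypothetical.

# Sketch only: mirrors the rewritten script, assuming the 2015-era spaCy API.
from preshed.counter import PreshCounter
from spacy.strings import hash_string
from spacy.en import English
from spacy.matcher import PhraseMatcher

nlp = English(parser=False, tagger=False, entity=False)

# PhraseMatcher takes pre-tokenized phrases, which is why read_gazetteer now
# runs each entry through the tokenizer and yields Doc objects.
phrases = [nlp.tokenizer(u'natural language processing')]  # hypothetical entry
matcher = PhraseMatcher(nlp.tokenizer.vocab, phrases, max_length=6)

counts = PreshCounter()
doc = nlp.tokenizer(u'I study natural language processing .')  # hypothetical text
matcher(doc)  # as in get_matches above, accepted matches appear on doc.ents
for mwe in doc.ents:
    counts.inc(hash_string(mwe.text), 1)

Since the script is driven by plac.call(main), its positional command-line arguments map onto main's parameters (patterns_loc, text_loc, counts_loc), with n available as an optional argument.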