diff --git a/bin/init_model.py b/bin/init_model.py index 0de17bdfa..d40e7813d 100644 --- a/bin/init_model.py +++ b/bin/init_model.py @@ -17,6 +17,8 @@ Requires: """ from __future__ import unicode_literals +from ast import literal_eval + import plac from pathlib import Path @@ -29,6 +31,8 @@ from spacy.en import get_lex_props from spacy.en.lemmatizer import Lemmatizer from spacy.vocab import Vocab from spacy.vocab import write_binary_vectors +from spacy.strings import hash_string +from preshed.counter import PreshCounter from spacy.parts_of_speech import NOUN, VERB, ADJ @@ -84,6 +88,28 @@ def _read_probs(loc): return probs +def _read_freqs(loc): + counts = PreshCounter() + total = 0 + for line in open(loc): + freq, doc_freq, key = line.split('\t', 2) + freq = int(freq) + counts[hash_string(key)] = freq + total += freq + counts.smooth() + log_total = math.log(total) + probs = {} + for line in open(loc): + freq, doc_freq, key = line.split('\t', 2) + if int(doc_freq) >= 2 and int(freq) >= 5 and len(key) < 200: + word = literal_eval(key) + smooth_count = counts.smoother(int(freq)) + log_smooth_count = math.log(smooth_count) + probs[word] = math.log(smooth_count) - log_total + probs['-OOV-'] = math.log(counts.smoother(0)) - log_total + return probs + + def _read_senses(loc): lexicon = defaultdict(lambda: defaultdict(list)) if not loc.exists(): @@ -115,6 +141,8 @@ def setup_vocab(src_dir, dst_dir): vocab = Vocab(data_dir=None, get_lex_props=get_lex_props) clusters = _read_clusters(src_dir / 'clusters.txt') probs = _read_probs(src_dir / 'words.sgt.prob') + if not probs: + probs = _read_freqs(src_dir / 'freqs.txt') if not probs: min_prob = 0.0 else: