diff --git a/bin/init_model.py b/bin/init_model.py index 3307bffa8..9a635f296 100644 --- a/bin/init_model.py +++ b/bin/init_model.py @@ -29,8 +29,6 @@ from shutil import copytree import codecs from collections import defaultdict -from spacy.en import get_lex_props -from spacy.en.lemmatizer import Lemmatizer from spacy.vocab import Vocab from spacy.vocab import write_binary_vectors from spacy.strings import hash_string @@ -38,6 +36,11 @@ from preshed.counter import PreshCounter from spacy.parts_of_speech import NOUN, VERB, ADJ +import spacy.en +import spacy.de + + + def setup_tokenizer(lang_data_dir, tok_dir): if not tok_dir.exists(): @@ -139,7 +142,7 @@ def _read_senses(loc): return lexicon -def setup_vocab(src_dir, dst_dir): +def setup_vocab(get_lex_attr, src_dir, dst_dir): if not dst_dir.exists(): dst_dir.mkdir() @@ -148,13 +151,13 @@ def setup_vocab(src_dir, dst_dir): write_binary_vectors(str(vectors_src), str(dst_dir / 'vec.bin')) else: print("Warning: Word vectors file not found") - vocab = Vocab(data_dir=None, get_lex_props=get_lex_props) + vocab = Vocab(data_dir=None, get_lex_attr=get_lex_attr) clusters = _read_clusters(src_dir / 'clusters.txt') probs, oov_prob = _read_probs(src_dir / 'words.sgt.prob') if not probs: probs, oov_prob = _read_freqs(src_dir / 'freqs.txt') if not probs: - oov_prob = 0.0 + oov_prob = -20 else: oov_prob = min(probs.values()) for word in clusters: @@ -163,23 +166,30 @@ def setup_vocab(src_dir, dst_dir): lexicon = [] for word, prob in reversed(sorted(list(probs.items()), key=lambda item: item[1])): - entry = get_lex_props(word) - entry['prob'] = float(prob) - cluster = clusters.get(word, '0') + lexeme = vocab[word] + lexeme.prob = prob + lexeme.is_oov = False # Decode as a little-endian string, so that we can do & 15 to get # the first 4 bits. See _parse_features.pyx - entry['cluster'] = int(cluster[::-1], 2) - vocab[word] = entry + if word in clusters: + lexeme.cluster = int(clusters[word][::-1], 2) + else: + lexeme.cluster = 0 vocab.dump(str(dst_dir / 'lexemes.bin')) vocab.strings.dump(str(dst_dir / 'strings.txt')) with (dst_dir / 'oov_prob').open('w') as file_: file_.write('%f' % oov_prob) -def main(lang_data_dir, corpora_dir, model_dir): +def main(lang_id, lang_data_dir, corpora_dir, model_dir): + languages = { + 'en': spacy.en.get_lex_attr, + 'de': spacy.en.get_lex_attr + } + model_dir = Path(model_dir) - lang_data_dir = Path(lang_data_dir) - corpora_dir = Path(corpora_dir) + lang_data_dir = Path(lang_data_dir) / lang_id + corpora_dir = Path(corpora_dir) / lang_id assert corpora_dir.exists() assert lang_data_dir.exists() @@ -188,12 +198,12 @@ def main(lang_data_dir, corpora_dir, model_dir): model_dir.mkdir() setup_tokenizer(lang_data_dir, model_dir / 'tokenizer') - setup_vocab(corpora_dir, model_dir / 'vocab') + setup_vocab(languages[lang_id], corpora_dir, model_dir / 'vocab') if (lang_data_dir / 'gazetteer.json').exists(): copyfile(str(lang_data_dir / 'gazetteer.json'), str(model_dir / 'vocab' / 'gazetteer.json')) - if not (model_dir / 'wordnet').exists(): + if not (model_dir / 'wordnet').exists() and (corpora_dir / 'wordnet').exists(): copytree(str(corpora_dir / 'wordnet' / 'dict'), str(model_dir / 'wordnet'))