diff --git a/spacy/cli/init_model.py b/spacy/cli/init_model.py index bad63209e..172345284 100644 --- a/spacy/cli/init_model.py +++ b/spacy/cli/init_model.py @@ -11,6 +11,8 @@ from preshed.counter import PreshCounter import tarfile import gzip import zipfile +import ujson as json +from spacy.lexeme import intify_attrs from ._messages import Messages from ..vectors import Vectors @@ -26,7 +28,8 @@ except ImportError: @plac.annotations( lang=("model language", "positional", None, str), output_dir=("model output directory", "positional", None, Path), - freqs_loc=("location of words frequencies file", "positional", None, Path), + freqs_loc=("location of words frequencies file", "optional", "f", Path), + jsonl_loc=("location of JSONL-formatted attributes file", "optional", "j", Path), clusters_loc=("optional: location of brown clusters data", "option", "c", str), vectors_loc=("optional: location of vectors file in Word2Vec format " @@ -35,20 +38,37 @@ except ImportError: prune_vectors=("optional: number of vectors to prune to", "option", "V", int) ) -def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, +def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, jsonl_loc=None, vectors_loc=None, prune_vectors=-1): """ Create a new model from raw data, like word frequencies, Brown clusters and word vectors. """ - if freqs_loc is not None and not freqs_loc.exists(): - prints(freqs_loc, title=Messages.M037, exits=1) - clusters_loc = ensure_path(clusters_loc) + if jsonl_loc is not None: + if freqs_loc is not None or clusters_loc is not None: + settings = ['-j'] + if freqs_loc: + settings.append('-f') + if clusters_loc: + settings.append('-c') + prints(' '.join(settings), + title=( + "The -f and -c arguments are deprecated, and not compatible " + "with the -j argument, which should specify the same information. " + "Either merge the frequencies and clusters data into the " + "jsonl-formatted file (recommended), or use only the -f and " + "-c files, without the other lexical attributes.")) + jsonl_loc = ensure_path(jsonl_loc) + lex_attrs = (json.loads(line) for line in jsonl_loc.open()) + else: + clusters_loc = ensure_path(clusters_loc) + freqs_loc = ensure_path(freqs_loc) + if freqs_loc is not None and not freqs_loc.exists(): + prints(freqs_loc, title=Messages.M037, exits=1) + lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc) vectors_loc = ensure_path(vectors_loc) - probs, oov_prob = read_freqs(freqs_loc) if freqs_loc is not None else ({}, -20) vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None) - clusters = read_clusters(clusters_loc) if clusters_loc else {} - nlp = create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors) + nlp = create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors) if not output_dir.exists(): output_dir.mkdir() nlp.to_disk(output_dir) @@ -70,26 +90,38 @@ def open_file(loc): else: return loc.open('r', encoding='utf8') -def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors): +def read_attrs_from_deprecated(freqs_loc, clusters_loc): + probs, oov_prob = read_freqs(freqs_loc) if freqs_loc is not None else ({}, -20) + clusters = read_clusters(clusters_loc) if clusters_loc else {} + lex_attrs = {} + sorted_probs = sorted(probs.items(), key=lambda item: item[1], reverse=True) + for i, (word, prob) in tqdm(enumerate(sorted_probs)): + attrs = {'orth': word, 'rank': i, 'prob': prob} + # Decode as a little-endian string, so that we can do & 15 to get + # the first 4 bits. See _parse_features.pyx + if word in clusters: + attrs['cluster'] = int(clusters[word][::-1], 2) + else: + attrs['cluster'] = 0 + lex_attrs.append(attrs) + return lex_attrs + + +def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors): print("Creating model...") lang_class = get_lang_class(lang) nlp = lang_class() for lexeme in nlp.vocab: lexeme.rank = 0 lex_added = 0 - for i, (word, prob) in enumerate(tqdm(sorted(probs.items(), key=lambda item: item[1], reverse=True))): - lexeme = nlp.vocab[word] - lexeme.rank = i - lexeme.prob = prob + for attrs in lex_attrs: + lexeme = nlp.vocab[attrs['orth']] + lexeme.set_attrs(**intify_attrs(attrs)) lexeme.is_oov = False - # Decode as a little-endian string, so that we can do & 15 to get - # the first 4 bits. See _parse_features.pyx - if word in clusters: - lexeme.cluster = int(clusters[word][::-1], 2) - else: - lexeme.cluster = 0 lex_added += 1 - nlp.vocab.cfg.update({'oov_prob': oov_prob}) + lex_added += 1 + oov_prob = min(lex.prob for lex in nlp.vocab) + nlp.vocab.cfg.update({'oov_prob': oov_prob-1}) if vector_keys is not None: for word in vector_keys: if word not in nlp.vocab: