From 044397e269c555665218296aa48c872600b04b89 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 21 Mar 2018 14:33:23 +0100 Subject: [PATCH] Support .gz and .tar.gz files in spacy init-model --- spacy/cli/init_model.py | 43 ++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/spacy/cli/init_model.py b/spacy/cli/init_model.py index 99a6e87eb..71efe1b2e 100644 --- a/spacy/cli/init_model.py +++ b/spacy/cli/init_model.py @@ -8,6 +8,8 @@ import numpy from ast import literal_eval from pathlib import Path from preshed.counter import PreshCounter +import tarfile +import gzip from ..compat import fix_text from ..vectors import Vectors @@ -25,17 +27,17 @@ from ..util import prints, ensure_path, get_lang_class prune_vectors=("optional: number of vectors to prune to", "option", "V", int) ) -def init_model(lang, output_dir, freqs_loc, clusters_loc=None, vectors_loc=None, prune_vectors=-1): +def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, vectors_loc=None, prune_vectors=-1): """ Create a new model from raw data, like word frequencies, Brown clusters and word vectors. """ - if not freqs_loc.exists(): + if freqs_loc is not None and not freqs_loc.exists(): prints(freqs_loc, title="Can't find words frequencies file", exits=1) clusters_loc = ensure_path(clusters_loc) vectors_loc = ensure_path(vectors_loc) - probs, oov_prob = read_freqs(freqs_loc) + probs, oov_prob = read_freqs(freqs_loc) if freqs_loc is not None else ({}, -20) vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None) clusters = read_clusters(clusters_loc) if clusters_loc else {} @@ -46,6 +48,16 @@ def init_model(lang, output_dir, freqs_loc, clusters_loc=None, vectors_loc=None, nlp.to_disk(output_dir) return nlp +def open_file(loc): + '''Handle .gz, .tar.gz or unzipped files''' + loc = ensure_path(loc) + if tarfile.is_tarfile(str(loc)): + return tarfile.open(str(loc), 'r:gz') + elif loc.parts[-1].endswith('gz'): + return (line.decode('utf8') for line in gzip.open(str(loc), 'r')) + else: + return loc.open('r', encoding='utf8') + def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors): print("Creating model...") @@ -68,6 +80,11 @@ def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, pru lexeme.cluster = 0 lex_added += 1 nlp.vocab.cfg.update({'oov_prob': oov_prob}) + for word in vector_keys: + if word not in nlp.vocab: + lexeme = nlp.vocab[word] + lexeme.is_oov = False + lex_added += 1 if len(vectors_data): nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys) @@ -81,16 +98,16 @@ def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, pru def read_vectors(vectors_loc): - print("Reading vectors...") - with vectors_loc.open() as f: - shape = tuple(int(size) for size in f.readline().split()) - vectors_data = numpy.zeros(shape=shape, dtype='f') - vectors_keys = [] - for i, line in enumerate(tqdm(f)): - pieces = line.split() - word = pieces.pop(0) - vectors_data[i] = numpy.array([float(val_str) for val_str in pieces], dtype='f') - vectors_keys.append(word) + print("Reading vectors from %s" % vectors_loc) + f = open_file(vectors_loc) + shape = tuple(int(size) for size in next(f).split()) + vectors_data = numpy.zeros(shape=shape, dtype='f') + vectors_keys = [] + for i, line in enumerate(tqdm(f)): + pieces = line.split() + word = pieces.pop(0) + vectors_data[i] = numpy.asarray(pieces, dtype='f') + vectors_keys.append(word) return vectors_data, vectors_keys