mirror of https://github.com/explosion/spaCy.git
Fix init-model for npz vectors
This commit is contained in:
parent
59d655e8d0
commit
dee8bdb900
|
@ -66,14 +66,14 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, jsonl_loc=No
|
|||
if freqs_loc is not None and not freqs_loc.exists():
|
||||
prints(freqs_loc, title=Messages.M037, exits=1)
|
||||
lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)
|
||||
vectors_loc = ensure_path(vectors_loc)
|
||||
if vectors_loc and vectors_loc.parts[-1].endswith('.npz'):
|
||||
vectors_data = numpy.load(vectors_loc.open('rb'))
|
||||
vector_keys = [lex['orth'] for lex in lex_attrs
|
||||
if 'id' in lex and lex['id'] < vectors_data.shape[0]]
|
||||
else:
|
||||
vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None)
|
||||
nlp = create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors)
|
||||
|
||||
nlp = create_model(lang, lex_attrs)
|
||||
if vectors_loc is not None:
|
||||
add_vectors(nlp, vectors_loc, prune_vectors)
|
||||
vec_added = len(nlp.vocab.vectors)
|
||||
lex_added = len(nlp.vocab)
|
||||
prints(Messages.M039.format(entries=lex_added, vectors=vec_added),
|
||||
title=Messages.M038)
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir()
|
||||
nlp.to_disk(output_dir)
|
||||
|
@ -112,7 +112,7 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc):
|
|||
return lex_attrs
|
||||
|
||||
|
||||
def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors):
|
||||
def create_model(lang, lex_attrs):
|
||||
print("Creating model...")
|
||||
lang_class = get_lang_class(lang)
|
||||
nlp = lang_class()
|
||||
|
@ -120,28 +120,38 @@ def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors):
|
|||
lexeme.rank = 0
|
||||
lex_added = 0
|
||||
for attrs in lex_attrs:
|
||||
if 'settings' in attrs:
|
||||
continue
|
||||
lexeme = nlp.vocab[attrs['orth']]
|
||||
lexeme.set_attrs(**intify_attrs(attrs))
|
||||
lexeme.set_attrs(**attrs)
|
||||
lexeme.is_oov = False
|
||||
lex_added += 1
|
||||
lex_added += 1
|
||||
oov_prob = min(lex.prob for lex in nlp.vocab)
|
||||
nlp.vocab.cfg.update({'oov_prob': oov_prob-1})
|
||||
if vector_keys is not None:
|
||||
for word in vector_keys:
|
||||
if word not in nlp.vocab:
|
||||
lexeme = nlp.vocab[word]
|
||||
lexeme.is_oov = False
|
||||
lex_added += 1
|
||||
if vectors_data is not None:
|
||||
nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
|
||||
if prune_vectors >= 1:
|
||||
nlp.vocab.prune_vectors(prune_vectors)
|
||||
vec_added = len(nlp.vocab.vectors)
|
||||
prints(Messages.M039.format(entries=lex_added, vectors=vec_added),
|
||||
title=Messages.M038)
|
||||
return nlp
|
||||
|
||||
def add_vectors(nlp, vectors_loc, prune_vectors):
    """Load word vectors from `vectors_loc` and attach them to `nlp.vocab`.

    Two input formats are handled:

    * A path whose last component ends in '.npz': the data is loaded with
      `numpy.load` and every lexeme already in the vocab that has a truthy
      `rank` is mapped to the vector row with that index.
    * Anything else: the location is passed to `read_vectors`, which returns
      the vector data and the keys; keys missing from the vocab are added
      and flagged as in-vocabulary before the vectors table is created.

    nlp: The language object whose vocab and meta are updated in place.
    vectors_loc: Location of the vectors; normalised via `ensure_path`.
    prune_vectors: Passed to `nlp.vocab.prune_vectors` when it is >= 1;
        smaller values disable pruning.
    """
    vectors_loc = ensure_path(vectors_loc)
    if vectors_loc and vectors_loc.parts[-1].endswith('.npz'):
        # NOTE(review): the handle returned by open() is handed straight to
        # numpy.load and never closed explicitly. It is left as-is because
        # a lazily-read archive would break if the file were closed early --
        # confirm what numpy.load returns for these files before changing.
        nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open('rb')))
        for lex in nlp.vocab:
            # A falsy rank (0) is skipped, so only lexemes with a non-zero
            # rank are given a row in the table.
            if lex.rank:
                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
    else:
        if vectors_loc:
            vectors_data, vector_keys = read_vectors(vectors_loc)
        else:
            vectors_data, vector_keys = None, None
        if vector_keys is not None:
            for word in vector_keys:
                if word not in nlp.vocab:
                    # Indexing the vocab creates the lexeme on demand.
                    lexeme = nlp.vocab[word]
                    lexeme.is_oov = False
                    # The previous version incremented `lex_added` here, a
                    # name never bound in this function, which raised
                    # UnboundLocalError on the first new word.
        if vectors_data is not None:
            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    # Name the table after the model's language and record it in the meta.
    nlp.vocab.vectors.name = '%s_model.vectors' % nlp.meta['lang']
    nlp.meta['vectors']['name'] = nlp.vocab.vectors.name
    if prune_vectors >= 1:
        nlp.vocab.prune_vectors(prune_vectors)
|
||||
|
||||
def read_vectors(vectors_loc):
|
||||
print("Reading vectors from %s" % vectors_loc)
|
||||
|
|
Loading…
Reference in New Issue