mirror of https://github.com/explosion/spaCy.git
bugfixing prune_vectors and vectors_loc
This commit is contained in:
parent
94a0cf46fd
commit
a30bc77415
|
@ -37,7 +37,7 @@ def init_model_cli(
|
||||||
clusters_loc: Optional[Path] = Opt(None, "--clusters-loc", "-c", help="Optional location of brown clusters data", exists=True),
|
clusters_loc: Optional[Path] = Opt(None, "--clusters-loc", "-c", help="Optional location of brown clusters data", exists=True),
|
||||||
jsonl_loc: Optional[Path] = Opt(None, "--jsonl-loc", "-j", help="Location of JSONL-formatted attributes file", exists=True),
|
jsonl_loc: Optional[Path] = Opt(None, "--jsonl-loc", "-j", help="Location of JSONL-formatted attributes file", exists=True),
|
||||||
vectors_loc: Optional[Path] = Opt(None, "--vectors-loc", "-v", help="Optional vectors file in Word2Vec format", exists=True),
|
vectors_loc: Optional[Path] = Opt(None, "--vectors-loc", "-v", help="Optional vectors file in Word2Vec format", exists=True),
|
||||||
prune_vectors: int = Opt(-1 , "--prune-vectors", "-V", help="Optional number of vectors to prune to"),
|
prune_vectors: int = Opt(-1, "--prune-vectors", "-V", help="Optional number of vectors to prune to"),
|
||||||
truncate_vectors: int = Opt(0, "--truncate-vectors", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
|
truncate_vectors: int = Opt(0, "--truncate-vectors", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
|
||||||
vectors_name: Optional[str] = Opt(None, "--vectors-name", "-vn", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
|
vectors_name: Optional[str] = Opt(None, "--vectors-name", "-vn", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
|
||||||
model_name: Optional[str] = Opt(None, "--model-name", "-mn", help="Optional name for the model meta"),
|
model_name: Optional[str] = Opt(None, "--model-name", "-mn", help="Optional name for the model meta"),
|
||||||
|
@ -56,6 +56,7 @@ def init_model_cli(
|
||||||
freqs_loc=freqs_loc,
|
freqs_loc=freqs_loc,
|
||||||
clusters_loc=clusters_loc,
|
clusters_loc=clusters_loc,
|
||||||
jsonl_loc=jsonl_loc,
|
jsonl_loc=jsonl_loc,
|
||||||
|
vectors_loc=vectors_loc,
|
||||||
prune_vectors=prune_vectors,
|
prune_vectors=prune_vectors,
|
||||||
truncate_vectors=truncate_vectors,
|
truncate_vectors=truncate_vectors,
|
||||||
vectors_name=vectors_name,
|
vectors_name=vectors_name,
|
||||||
|
@ -228,7 +229,7 @@ def add_vectors(
|
||||||
else:
|
else:
|
||||||
if vectors_loc:
|
if vectors_loc:
|
||||||
with msg.loading(f"Reading vectors from {vectors_loc}"):
|
with msg.loading(f"Reading vectors from {vectors_loc}"):
|
||||||
vectors_data, vector_keys = read_vectors(msg, vectors_loc)
|
vectors_data, vector_keys = read_vectors(msg, vectors_loc, truncate_vectors)
|
||||||
msg.good(f"Loaded vectors from {vectors_loc}")
|
msg.good(f"Loaded vectors from {vectors_loc}")
|
||||||
else:
|
else:
|
||||||
vectors_data, vector_keys = (None, None)
|
vectors_data, vector_keys = (None, None)
|
||||||
|
@ -247,7 +248,7 @@ def add_vectors(
|
||||||
nlp.vocab.prune_vectors(prune_vectors)
|
nlp.vocab.prune_vectors(prune_vectors)
|
||||||
|
|
||||||
|
|
||||||
def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int = 0):
|
def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int):
|
||||||
f = open_file(vectors_loc)
|
f = open_file(vectors_loc)
|
||||||
shape = tuple(int(size) for size in next(f).split())
|
shape = tuple(int(size) for size in next(f).split())
|
||||||
if truncate_vectors >= 1:
|
if truncate_vectors >= 1:
|
||||||
|
|
Loading…
Reference in New Issue