mirror of https://github.com/explosion/spaCy.git
Add option for base model in init-model CLI (#5467)
Intended for languages like Chinese with a custom tokenizer.
This commit is contained in:
parent
9393253b66
commit
49ef06d793
|
@ -17,7 +17,7 @@ from wasabi import msg
|
||||||
|
|
||||||
from ..vectors import Vectors
|
from ..vectors import Vectors
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
from ..util import ensure_path, get_lang_class, OOV_RANK
|
from ..util import ensure_path, get_lang_class, load_model, OOV_RANK
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ftfy
|
import ftfy
|
||||||
|
@ -49,6 +49,7 @@ DEFAULT_OOV_PROB = -20
|
||||||
str,
|
str,
|
||||||
),
|
),
|
||||||
model_name=("Optional name for the model meta", "option", "mn", str),
|
model_name=("Optional name for the model meta", "option", "mn", str),
|
||||||
|
base_model=("Base model (for languages with custom tokenizers)", "option", "b", str),
|
||||||
)
|
)
|
||||||
def init_model(
|
def init_model(
|
||||||
lang,
|
lang,
|
||||||
|
@ -61,6 +62,7 @@ def init_model(
|
||||||
prune_vectors=-1,
|
prune_vectors=-1,
|
||||||
vectors_name=None,
|
vectors_name=None,
|
||||||
model_name=None,
|
model_name=None,
|
||||||
|
base_model=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new model from raw data, like word frequencies, Brown clusters
|
Create a new model from raw data, like word frequencies, Brown clusters
|
||||||
|
@ -92,7 +94,7 @@ def init_model(
|
||||||
lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)
|
lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)
|
||||||
|
|
||||||
with msg.loading("Creating model..."):
|
with msg.loading("Creating model..."):
|
||||||
nlp = create_model(lang, lex_attrs, name=model_name)
|
nlp = create_model(lang, lex_attrs, name=model_name, base_model=base_model)
|
||||||
msg.good("Successfully created model")
|
msg.good("Successfully created model")
|
||||||
if vectors_loc is not None:
|
if vectors_loc is not None:
|
||||||
add_vectors(nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name)
|
add_vectors(nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name)
|
||||||
|
@ -152,9 +154,16 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc):
|
||||||
return lex_attrs
|
return lex_attrs
|
||||||
|
|
||||||
|
|
||||||
def create_model(lang, lex_attrs, name=None):
|
def create_model(lang, lex_attrs, name=None, base_model=None):
|
||||||
lang_class = get_lang_class(lang)
|
if base_model:
|
||||||
nlp = lang_class()
|
nlp = load_model(base_model)
|
||||||
|
# keep the tokenizer but remove any existing pipeline components due to
|
||||||
|
# potentially conflicting vectors
|
||||||
|
for pipe in nlp.pipe_names:
|
||||||
|
nlp.remove_pipe(pipe)
|
||||||
|
else:
|
||||||
|
lang_class = get_lang_class(lang)
|
||||||
|
nlp = lang_class()
|
||||||
for lexeme in nlp.vocab:
|
for lexeme in nlp.vocab:
|
||||||
lexeme.rank = OOV_RANK
|
lexeme.rank = OOV_RANK
|
||||||
for attrs in lex_attrs:
|
for attrs in lex_attrs:
|
||||||
|
|
Loading…
Reference in New Issue