Add option for base model in init-model CLI (#5467)

Intended for languages like Chinese with a custom tokenizer.
This commit is contained in:
adrianeboyd 2020-05-20 18:49:11 +02:00 committed by GitHub
parent 9393253b66
commit 49ef06d793
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 14 additions and 5 deletions

View File

@@ -17,7 +17,7 @@ from wasabi import msg
from ..vectors import Vectors from ..vectors import Vectors
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..util import ensure_path, get_lang_class, OOV_RANK from ..util import ensure_path, get_lang_class, load_model, OOV_RANK
try: try:
import ftfy import ftfy
@@ -49,6 +49,7 @@ DEFAULT_OOV_PROB = -20
str, str,
), ),
model_name=("Optional name for the model meta", "option", "mn", str), model_name=("Optional name for the model meta", "option", "mn", str),
base_model=("Base model (for languages with custom tokenizers)", "option", "b", str),
) )
def init_model( def init_model(
lang, lang,
@@ -61,6 +62,7 @@ def init_model(
prune_vectors=-1, prune_vectors=-1,
vectors_name=None, vectors_name=None,
model_name=None, model_name=None,
base_model=None,
): ):
""" """
Create a new model from raw data, like word frequencies, Brown clusters Create a new model from raw data, like word frequencies, Brown clusters
@@ -92,7 +94,7 @@ def init_model(
lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc) lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)
with msg.loading("Creating model..."): with msg.loading("Creating model..."):
nlp = create_model(lang, lex_attrs, name=model_name) nlp = create_model(lang, lex_attrs, name=model_name, base_model=base_model)
msg.good("Successfully created model") msg.good("Successfully created model")
if vectors_loc is not None: if vectors_loc is not None:
add_vectors(nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name) add_vectors(nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name)
@@ -152,9 +154,16 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc):
return lex_attrs return lex_attrs
def create_model(lang, lex_attrs, name=None): def create_model(lang, lex_attrs, name=None, base_model=None):
lang_class = get_lang_class(lang) if base_model:
nlp = lang_class() nlp = load_model(base_model)
# keep the tokenizer but remove any existing pipeline components due to
# potentially conflicting vectors
for pipe in nlp.pipe_names:
nlp.remove_pipe(pipe)
else:
lang_class = get_lang_class(lang)
nlp = lang_class()
for lexeme in nlp.vocab: for lexeme in nlp.vocab:
lexeme.rank = OOV_RANK lexeme.rank = OOV_RANK
for attrs in lex_attrs: for attrs in lex_attrs: