Add -t2v argument to train_textcat script

This commit is contained in:
Matthew Honnibal 2019-03-20 23:05:11 +01:00
parent 764359c952
commit 4e3ed2ea88
1 changed files with 5 additions and 1 deletions

View File

@ -24,8 +24,9 @@ from spacy.util import minibatch, compounding
output_dir=("Optional output directory", "option", "o", Path),
n_texts=("Number of texts to train from", "option", "t", int),
n_iter=("Number of training iterations", "option", "n", int),
init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path)
)
def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None):
if output_dir is not None:
output_dir = Path(output_dir)
if not output_dir.exists():
@ -67,6 +68,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
with nlp.disable_pipes(*other_pipes): # only train textcat
optimizer = nlp.begin_training()
if init_tok2vec is not None:
with init_tok2vec.open("rb") as file_:
textcat.model.tok2vec.from_bytes(file_.read())
print("Training the model...")
print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
for i in range(n_iter):