mirror of https://github.com/explosion/spaCy.git
Add -t2v argument to train_textcat script
This commit is contained in:
parent
764359c952
commit
4e3ed2ea88
|
@ -24,8 +24,9 @@ from spacy.util import minibatch, compounding
|
|||
output_dir=("Optional output directory", "option", "o", Path),
|
||||
n_texts=("Number of texts to train from", "option", "t", int),
|
||||
n_iter=("Number of training iterations", "option", "n", int),
|
||||
init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path)
|
||||
)
|
||||
def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
|
||||
def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None):
|
||||
if output_dir is not None:
|
||||
output_dir = Path(output_dir)
|
||||
if not output_dir.exists():
|
||||
|
@ -67,6 +68,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
|
|||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
|
||||
with nlp.disable_pipes(*other_pipes): # only train textcat
|
||||
optimizer = nlp.begin_training()
|
||||
if init_tok2vec is not None:
|
||||
with init_tok2vec.open("rb") as file_:
|
||||
textcat.model.tok2vec.from_bytes(file_.read())
|
||||
print("Training the model...")
|
||||
print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
|
||||
for i in range(n_iter):
|
||||
|
|
Loading…
Reference in New Issue