From dad8f09fba1fe00cf24f3ab5d920bbe4168f4709 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 1 Nov 2017 16:34:31 +0100 Subject: [PATCH] Fix print statements in text classifier example --- examples/training/train_textcat.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index 852635075..6fa79e75b 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -26,8 +26,9 @@ from spacy.pipeline import TextCategorizer @plac.annotations( model=("Model name. Defaults to blank 'en' model.", "option", "m", str), output_dir=("Optional output directory", "option", "o", Path), + n_examples=("Number of texts to train from", "option", "N", int), n_iter=("Number of training iterations", "option", "n", int)) -def main(model=None, output_dir=None, n_iter=20): +def main(model=None, output_dir=None, n_iter=20, n_texts=2000): if model is not None: nlp = spacy.load(model) # load existing spaCy model print("Loaded model '%s'" % model) @@ -50,7 +51,8 @@ def main(model=None, output_dir=None, n_iter=20): # load the IMBD dataset print("Loading IMDB data...") - (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=2000) + print("Using %d training examples" % n_texts) + (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=n_texts) train_docs = [nlp.tokenizer(text) for text in train_texts] train_gold = [GoldParse(doc, cats=cats) for doc, cats in zip(train_docs, train_cats)] @@ -65,14 +67,14 @@ def main(model=None, output_dir=None, n_iter=20): for i in range(n_iter): losses = {} # batch up the examples using spaCy's minibatch - batches = minibatch(train_data, size=compounding(4., 128., 1.001)) + batches = minibatch(train_data, size=compounding(4., 32., 1.001)) for batch in batches: docs, golds = zip(*batch) nlp.update(docs, golds, sgd=optimizer, drop=0.2, losses=losses) with textcat.model.use_params(optimizer.averages): # evaluate on the dev data split off in load_data() scores = evaluate(nlp.tokenizer, textcat, dev_texts, dev_cats) - print('{0:.3f}\t{0:.3f}\t{0:.3f}\t{0:.3f}' # print a simple table + print('{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}' # print a simple table .format(losses['textcat'], scores['textcat_p'], scores['textcat_r'], scores['textcat_f']))