Update textcat example

ines 2017-11-01 17:09:22 +01:00
parent 1ae40b50b4
commit 8f1d3fc3ee
1 changed file with 4 additions and 5 deletions

@@ -26,7 +26,7 @@ from spacy.pipeline import TextCategorizer
 @plac.annotations(
     model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
     output_dir=("Optional output directory", "option", "o", Path),
-    n_examples=("Number of texts to train from", "option", "N", int),
+    n_texts=("Number of texts to train from", "option", "t", int),
     n_iter=("Number of training iterations", "option", "n", int))
 def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
     if model is not None:
@@ -39,20 +39,19 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
     # add the text classifier to the pipeline if it doesn't exist
     # nlp.create_pipe works for built-ins that are registered with spaCy
     if 'textcat' not in nlp.pipe_names:
-        # textcat = nlp.create_pipe('textcat')
-        textcat = TextCategorizer(nlp.vocab, labels=['POSITIVE'])
+        textcat = nlp.create_pipe('textcat')
         nlp.add_pipe(textcat, last=True)
     # otherwise, get it, so we can add labels to it
     else:
         textcat = nlp.get_pipe('textcat')
     # add label to text classifier
-    # textcat.add_label('POSITIVE')
+    textcat.add_label('POSITIVE')
     # load the IMBD dataset
     print("Loading IMDB data...")
-    print("Using %d training examples" % n_texts)
     (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=n_texts)
+    print("Using %d training examples" % n_texts)
     train_docs = [nlp.tokenizer(text) for text in train_texts]
     train_gold = [GoldParse(doc, cats=cats) for doc, cats in
                   zip(train_docs, train_cats)]
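
For context, a minimal sketch of the pattern the example switches to with this commit (spaCy 2.x API): the built-in 'textcat' component is created with nlp.create_pipe() and labels are registered via add_label(), rather than instantiating TextCategorizer directly with a labels argument. The blank-model setup below is an assumption for illustration, not part of the commit, which also supports loading an existing model.

    import spacy

    # Sketch only: assumes a blank English pipeline; the real example
    # can also load an existing model via its -m option.
    nlp = spacy.blank('en')
    textcat = nlp.create_pipe('textcat')  # built-in component registered with spaCy
    nlp.add_pipe(textcat, last=True)      # append to the end of the pipeline
    textcat.add_label('POSITIVE')         # labels are added after creation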