diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index e76693b8d..a81b7fbe5 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -41,7 +41,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000): # add the text classifier to the pipeline if it doesn't exist # nlp.create_pipe works for built-ins that are registered with spaCy if "textcat" not in nlp.pipe_names: - textcat = nlp.create_pipe("textcat") + textcat = nlp.create_pipe("textcat", config={ + "architecture": "simple_cnn", + "exclusive_classes": True}) nlp.add_pipe(textcat, last=True) # otherwise, get it, so we can add labels to it else: @@ -70,7 +72,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000): for i in range(n_iter): losses = {} # batch up the examples using spaCy's minibatch - batches = minibatch(train_data, size=compounding(4.0, 16.0, 1.001)) + batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001)) for batch in batches: texts, annotations = zip(*batch) nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)