From 135de82a2d7073d535d1ffd1e4254e5dca37c046 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Tue, 22 Sep 2020 10:22:06 +0200 Subject: [PATCH] add textcat to quickstart --- spacy/cli/templates/quickstart_training.jinja | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 0db4c8a59..2c7ce024b 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -93,6 +93,29 @@ grad_factor = 1.0 @layers = "reduce_mean.v1" {% endif -%} +{% if "textcat" in components %} +[components.textcat] +factory = "textcat" + +{% if optimize == "accuracy" %} +[components.textcat.model] +@architectures = "spacy.TextCatEnsemble.v1" +exclusive_classes = false +width = 64 +conv_depth = 2 +embed_size = 2000 +window_size = 1 +ngram_size = 1 +nO = null + +{% else -%} +[components.textcat.model] +@architectures = "spacy.TextCatBOW.v1" +exclusive_classes = false +ngram_size = 1 +{%- endif %} +{%- endif %} + {# NON-TRANSFORMER PIPELINE #} {% else -%} @@ -167,10 +190,33 @@ nO = null @architectures = "spacy.Tok2VecListener.v1" width = ${components.tok2vec.model.encode.width} {% endif %} + +{% if "textcat" in components %} +[components.textcat] +factory = "textcat" + +{% if optimize == "accuracy" %} +[components.textcat.model] +@architectures = "spacy.TextCatEnsemble.v1" +exclusive_classes = false +width = 64 +conv_depth = 2 +embed_size = 2000 +window_size = 1 +ngram_size = 1 +nO = null + +{% else -%} +[components.textcat.model] +@architectures = "spacy.TextCatBOW.v1" +exclusive_classes = false +ngram_size = 1 +{%- endif %} +{%- endif %} {% endif %} {% for pipe in components %} -{% if pipe not in ["tagger", "parser", "ner"] %} +{% if pipe not in ["tagger", "parser", "ner", "textcat"] %} {# Other components defined by the user: we just assume they're factories #} [components.{{ pipe }}] factory = "{{ pipe }}"