diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 1194438de..37983cb1a 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -143,6 +143,9 @@ nO = null @architectures = "spacy-transformers.TransformerListener.v1" grad_factor = 1.0 +[components.textcat.model.tok2vec.pooling] +@layers = "reduce_mean.v1" + [components.textcat.model.linear_model] @architectures = "spacy.TextCatBOW.v1" exclusive_classes = false diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index d4aed2839..2ec036810 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -61,14 +61,14 @@ def build_bow_text_classifier( @registry.architectures.register("spacy.TextCatEnsemble.v2") -def build_text_classifier( +def build_text_classifier_v2( tok2vec: Model[List[Doc], List[Floats2d]], linear_model: Model[List[Doc], Floats2d], nO: Optional[int] = None, ) -> Model[List[Doc], Floats2d]: exclusive_classes = not linear_model.attrs["multi_label"] with Model.define_operators({">>": chain, "|": concatenate}): - width = tok2vec.get_dim("nO") + width = tok2vec.maybe_get_dim("nO") cnn_model = ( tok2vec >> list2ragged() @@ -94,7 +94,7 @@ def build_text_classifier( # TODO: move to legacy @registry.architectures.register("spacy.TextCatEnsemble.v1") -def build_text_classifier( +def build_text_classifier_v1( width: int, embed_size: int, pretrained_vectors: Optional[bool],