From a0c899a0fff08e09f7ebabb8e0e50baa4f4b0897 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 10 Nov 2020 13:14:47 +0100 Subject: [PATCH] Fix textcat + transformer architecture (#6371) * add pooling to textcat TransformerListener * maybe_get_dim in case it's null --- spacy/cli/templates/quickstart_training.jinja | 3 +++ spacy/ml/models/textcat.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 1194438de..37983cb1a 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -143,6 +143,9 @@ nO = null @architectures = "spacy-transformers.TransformerListener.v1" grad_factor = 1.0 +[components.textcat.model.tok2vec.pooling] +@layers = "reduce_mean.v1" + [components.textcat.model.linear_model] @architectures = "spacy.TextCatBOW.v1" exclusive_classes = false diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index d4aed2839..2ec036810 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -61,14 +61,14 @@ def build_bow_text_classifier( @registry.architectures.register("spacy.TextCatEnsemble.v2") -def build_text_classifier( +def build_text_classifier_v2( tok2vec: Model[List[Doc], List[Floats2d]], linear_model: Model[List[Doc], Floats2d], nO: Optional[int] = None, ) -> Model[List[Doc], Floats2d]: exclusive_classes = not linear_model.attrs["multi_label"] with Model.define_operators({">>": chain, "|": concatenate}): - width = tok2vec.get_dim("nO") + width = tok2vec.maybe_get_dim("nO") cnn_model = ( tok2vec >> list2ragged() @@ -94,7 +94,7 @@ def build_text_classifier( # TODO: move to legacy @registry.architectures.register("spacy.TextCatEnsemble.v1") -def build_text_classifier( +def build_text_classifier_v1( width: int, embed_size: int, pretrained_vectors: Optional[bool],