diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 100b5abfd..4e052ef16 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -24,7 +24,8 @@ from ..vocab cimport Vocab from ..syntax import nonproj from ..attrs import POS, ID from ..parts_of_speech import X -from .._ml import Tok2Vec, build_tagger_model, build_simple_cnn_text_classifier +from .._ml import Tok2Vec, build_tagger_model +from .._ml import build_text_classifier, build_simple_cnn_text_classifier from .._ml import link_vectors_to_models, zero_init, flatten from .._ml import masked_language_model, create_default_optimizer from ..errors import Errors, TempErrors @@ -862,8 +863,11 @@ class TextCategorizer(Pipe): token_vector_width = cfg["token_vector_width"] else: token_vector_width = util.env_opt("token_vector_width", 96) - tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg) - return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg) + if cfg.get('architecture') == 'simple_cnn': + tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg) + return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg) + else: + return build_text_classifier(nr_class, **cfg) @property def tok2vec(self):