diff --git a/spacy/_ml.py b/spacy/_ml.py index 511babf3c..ba9c3b634 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -564,18 +564,26 @@ def build_text_classifier(nr_class, width=64, **cfg): ) linear_model = _preprocess_doc >> LinearModel(nr_class) + if cfg.get('exclusive_classes'): + output_layer = Softmax(nr_class, nr_class * 2) + else: + output_layer = ( + zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0)) + >> logistic + ) + + model = ( (linear_model | cnn_model) - >> zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0)) - >> logistic + >> output_layer ) - model.tok2vec = tok2vec + model.tok2vec = chain(tok2vec, flatten) model.nO = nr_class model.lsuv = False return model -def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True, **cfg): +def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg): """ Build a simple CNN text classifier, given a token-to-vector model as inputs. If exclusive_classes=True, a softmax non-linearity is applied, so that the