mirror of https://github.com/explosion/spaCy.git
Support exclusive_classes setting for textcat models
This commit is contained in:
parent
ce1e4eace2
commit
e9dd5943b9
16
spacy/_ml.py
16
spacy/_ml.py
|
@ -564,18 +564,26 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
|||
)
|
||||
|
||||
linear_model = _preprocess_doc >> LinearModel(nr_class)
|
||||
if cfg.get('exclusive_classes'):
|
||||
output_layer = Softmax(nr_class, nr_class * 2)
|
||||
else:
|
||||
output_layer = (
|
||||
zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0))
|
||||
>> logistic
|
||||
)
|
||||
|
||||
|
||||
model = (
|
||||
(linear_model | cnn_model)
|
||||
>> zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0))
|
||||
>> logistic
|
||||
>> output_layer
|
||||
)
|
||||
model.tok2vec = tok2vec
|
||||
model.tok2vec = chain(tok2vec, flatten)
|
||||
model.nO = nr_class
|
||||
model.lsuv = False
|
||||
return model
|
||||
|
||||
|
||||
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True, **cfg):
|
||||
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg):
|
||||
"""
|
||||
Build a simple CNN text classifier, given a token-to-vector model as inputs.
|
||||
If exclusive_classes=True, a softmax non-linearity is applied, so that the
|
||||
|
|
Loading…
Reference in New Issue