mirror of https://github.com/explosion/spaCy.git
Support exclusive_classes setting for textcat models
This commit is contained in:
parent
ce1e4eace2
commit
e9dd5943b9
18
spacy/_ml.py
18
spacy/_ml.py
|
@ -564,18 +564,26 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
)
|
)
|
||||||
|
|
||||||
linear_model = _preprocess_doc >> LinearModel(nr_class)
|
linear_model = _preprocess_doc >> LinearModel(nr_class)
|
||||||
model = (
|
if cfg.get('exclusive_classes'):
|
||||||
(linear_model | cnn_model)
|
output_layer = Softmax(nr_class, nr_class * 2)
|
||||||
>> zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0))
|
else:
|
||||||
|
output_layer = (
|
||||||
|
zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0))
|
||||||
>> logistic
|
>> logistic
|
||||||
)
|
)
|
||||||
model.tok2vec = tok2vec
|
|
||||||
|
|
||||||
|
model = (
|
||||||
|
(linear_model | cnn_model)
|
||||||
|
>> output_layer
|
||||||
|
)
|
||||||
|
model.tok2vec = chain(tok2vec, flatten)
|
||||||
model.nO = nr_class
|
model.nO = nr_class
|
||||||
model.lsuv = False
|
model.lsuv = False
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True, **cfg):
|
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg):
|
||||||
"""
|
"""
|
||||||
Build a simple CNN text classifier, given a token-to-vector model as inputs.
|
Build a simple CNN text classifier, given a token-to-vector model as inputs.
|
||||||
If exclusive_classes=True, a softmax non-linearity is applied, so that the
|
If exclusive_classes=True, a softmax non-linearity is applied, so that the
|
||||||
|
|
Loading…
Reference in New Issue