Support exclusive_classes setting for textcat models

This commit is contained in:
Matthew Honnibal 2019-02-23 11:57:16 +01:00
parent ce1e4eace2
commit e9dd5943b9
1 changed files with 12 additions and 4 deletions

View File

@ -564,18 +564,26 @@ def build_text_classifier(nr_class, width=64, **cfg):
)
linear_model = _preprocess_doc >> LinearModel(nr_class)
if cfg.get('exclusive_classes'):
output_layer = Softmax(nr_class, nr_class * 2)
else:
output_layer = (
zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0))
>> logistic
)
model = (
(linear_model | cnn_model)
>> zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0))
>> logistic
>> output_layer
)
model.tok2vec = tok2vec
model.tok2vec = chain(tok2vec, flatten)
model.nO = nr_class
model.lsuv = False
return model
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True, **cfg):
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg):
"""
Build a simple CNN text classifier, given a token-to-vector model as inputs.
If exclusive_classes=True, a softmax non-linearity is applied, so that the