Support exclusive_classes setting for textcat models

This commit is contained in:
Matthew Honnibal 2019-02-23 11:57:16 +01:00
parent ce1e4eace2
commit e9dd5943b9
1 changed file with 12 additions and 4 deletions

View File

@@ -564,18 +564,26 @@ def build_text_classifier(nr_class, width=64, **cfg):
) )
linear_model = _preprocess_doc >> LinearModel(nr_class) linear_model = _preprocess_doc >> LinearModel(nr_class)
model = ( if cfg.get('exclusive_classes'):
(linear_model | cnn_model) output_layer = Softmax(nr_class, nr_class * 2)
>> zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0)) else:
output_layer = (
zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0))
>> logistic >> logistic
) )
model.tok2vec = tok2vec
model = (
(linear_model | cnn_model)
>> output_layer
)
model.tok2vec = chain(tok2vec, flatten)
model.nO = nr_class model.nO = nr_class
model.lsuv = False model.lsuv = False
return model return model
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True, **cfg): def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg):
""" """
Build a simple CNN text classifier, given a token-to-vector model as inputs. Build a simple CNN text classifier, given a token-to-vector model as inputs.
If exclusive_classes=True, a softmax non-linearity is applied, so that the If exclusive_classes=True, a softmax non-linearity is applied, so that the