Adjust text classification model

This commit is contained in:
Matthew Honnibal 2017-09-02 11:41:00 +02:00
parent ac040b99bb
commit a824cf8f9a
1 changed files with 15 additions and 10 deletions

View File

@ -24,7 +24,7 @@ from thinc.linear.linear import LinearModel
from thinc.api import uniqued, wrap, flatten_add_lengths
from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP
from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP, CLUSTER
from .tokens.doc import Doc
from . import util
@ -521,25 +521,30 @@ def build_text_classifier(nr_class, width=64, **cfg):
trained_vectors = (
FeatureExtracter([ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID])
>> with_flatten(
(lower | prefix | suffix | shape)
uniqued(
(lower | prefix | suffix | shape)
>> LN(Maxout(width, 64+32+32+32)),
column=0
)
)
)
convolution = (
ExtractWindow(nW=1)
>> LN(Maxout(width, width*3))
static_vectors = (
SpacyVectors
>> with_flatten(Affine(width, 300))
)
cnn_model = (
# TODO Make concatenate support lists
concatenate_lists(trained_vectors, SpacyVectors)
>> with_flatten(
LN(Maxout(width, 64+32+32+32+300))
>> convolution ** 4, pad=4)
concatenate_lists(trained_vectors, static_vectors)
>> flatten_add_lengths
>> with_getitem(0,
SELU(width, width*2)
>> (ExtractWindow(nW=1) >> SELU(width, width*3)) ** 2
)
>> ParametricAttention(width)
>> Pooling(sum_pool)
>> ReLu(width, width)
>> SELU(width, width) ** 2
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
)