mirror of https://github.com/explosion/spaCy.git
Adjust text classification model
This commit is contained in:
parent
ac040b99bb
commit
a824cf8f9a
23
spacy/_ml.py
23
spacy/_ml.py
|
@ -24,7 +24,7 @@ from thinc.linear.linear import LinearModel
|
|||
from thinc.api import uniqued, wrap, flatten_add_lengths
|
||||
|
||||
|
||||
from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||
from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP, CLUSTER
|
||||
from .tokens.doc import Doc
|
||||
from . import util
|
||||
|
||||
|
@ -521,25 +521,30 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
|||
trained_vectors = (
|
||||
FeatureExtracter([ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID])
|
||||
>> with_flatten(
|
||||
uniqued(
|
||||
(lower | prefix | suffix | shape)
|
||||
>> LN(Maxout(width, 64+32+32+32)),
|
||||
column=0
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
convolution = (
|
||||
ExtractWindow(nW=1)
|
||||
>> LN(Maxout(width, width*3))
|
||||
static_vectors = (
|
||||
SpacyVectors
|
||||
>> with_flatten(Affine(width, 300))
|
||||
)
|
||||
|
||||
cnn_model = (
|
||||
# TODO Make concatenate support lists
|
||||
concatenate_lists(trained_vectors, SpacyVectors)
|
||||
>> with_flatten(
|
||||
LN(Maxout(width, 64+32+32+32+300))
|
||||
>> convolution ** 4, pad=4)
|
||||
concatenate_lists(trained_vectors, static_vectors)
|
||||
>> flatten_add_lengths
|
||||
>> with_getitem(0,
|
||||
SELU(width, width*2)
|
||||
>> (ExtractWindow(nW=1) >> SELU(width, width*3)) ** 2
|
||||
)
|
||||
>> ParametricAttention(width)
|
||||
>> Pooling(sum_pool)
|
||||
>> ReLu(width, width)
|
||||
>> SELU(width, width) ** 2
|
||||
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue