mirror of https://github.com/explosion/spaCy.git
Add depth argument for text classifier
This commit is contained in:
parent
13067095a1
commit
3cdee79a0c
|
@ -450,6 +450,7 @@ def SpacyVectors(docs, drop=0.):
|
||||||
|
|
||||||
|
|
||||||
def build_text_classifier(nr_class, width=64, **cfg):
|
def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
|
depth = cfg.get('depth', 2)
|
||||||
nr_vector = cfg.get('nr_vector', 5000)
|
nr_vector = cfg.get('nr_vector', 5000)
|
||||||
pretrained_dims = cfg.get('pretrained_dims', 0)
|
pretrained_dims = cfg.get('pretrained_dims', 0)
|
||||||
with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
|
with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
|
||||||
|
@ -501,7 +502,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
LN(Maxout(width, vectors_width))
|
LN(Maxout(width, vectors_width))
|
||||||
>> Residual(
|
>> Residual(
|
||||||
(ExtractWindow(nW=1) >> LN(Maxout(width, width*3)))
|
(ExtractWindow(nW=1) >> LN(Maxout(width, width*3)))
|
||||||
) ** 2, pad=2
|
) ** depth, pad=depth
|
||||||
)
|
)
|
||||||
>> flatten_add_lengths
|
>> flatten_add_lengths
|
||||||
>> ParametricAttention(width)
|
>> ParametricAttention(width)
|
||||||
|
|
Loading…
Reference in New Issue