mirror of https://github.com/explosion/spaCy.git
rewrite Maxout layer as separate layers to avoid shape inference trouble (#6760)
This commit is contained in:
parent
26c34ab8b0
commit
c8761b0e6e
|
@ -4,7 +4,7 @@ from thinc.types import Floats2d
|
|||
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
|
||||
from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
|
||||
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
|
||||
from thinc.api import with_cpu, Relu, residual
|
||||
from thinc.api import with_cpu, Relu, residual, LayerNorm
|
||||
from thinc.layers.chain import init as init_chain
|
||||
|
||||
from ...attrs import ORTH
|
||||
|
@ -72,13 +72,14 @@ def build_text_classifier_v2(
|
|||
attention_layer = ParametricAttention(
|
||||
width
|
||||
) # TODO: benchmark performance difference of this layer
|
||||
maxout_layer = Maxout(nO=width, nI=width, dropout=0.0, normalize=True)
|
||||
maxout_layer = Maxout(nO=width, nI=width)
|
||||
norm_layer = LayerNorm(nI=width)
|
||||
cnn_model = (
|
||||
tok2vec
|
||||
>> list2ragged()
|
||||
>> attention_layer
|
||||
>> reduce_sum()
|
||||
>> residual(maxout_layer)
|
||||
>> residual(maxout_layer >> norm_layer >> Dropout(0.0))
|
||||
)
|
||||
|
||||
nO_double = nO * 2 if nO else None
|
||||
|
@ -93,6 +94,7 @@ def build_text_classifier_v2(
|
|||
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
|
||||
model.set_ref("attention_layer", attention_layer)
|
||||
model.set_ref("maxout_layer", maxout_layer)
|
||||
model.set_ref("norm_layer", norm_layer)
|
||||
model.attrs["multi_label"] = not exclusive_classes
|
||||
|
||||
model.init = init_ensemble_textcat
|
||||
|
@ -104,6 +106,7 @@ def init_ensemble_textcat(model, X, Y) -> Model:
|
|||
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
|
||||
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
|
||||
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
|
||||
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
|
||||
init_chain(model, X, Y)
|
||||
return model
|
||||
|
||||
|
|
Loading…
Reference in New Issue