mirror of https://github.com/explosion/spaCy.git
Fix Tok2Vec
parent 386c1a5bd8
commit 4bd6a12b1f
@@ -475,14 +475,16 @@ def getitem(i):
     return layerize(getitem_fwd)
 
 
-def build_tagger_model(nr_class, token_vector_width, pretrained_dims=0, **cfg):
+def build_tagger_model(nr_class, pretrained_dims=0, **cfg):
     embed_size = util.env_opt('embed_size', 4000)
+    if 'token_vector_width' not in cfg:
+        token_vector_width = util.env_opt('token_vector_width', 128)
     with Model.define_operators({'>>': chain, '+': add}):
         tok2vec = Tok2Vec(token_vector_width, embed_size,
                           pretrained_dims=pretrained_dims)
-        model = with_flatten(
+        model = (
             tok2vec
-            >> Softmax(nr_class, token_vector_width)
+            >> with_flatten(Softmax(nr_class, token_vector_width))
         )
     model.nI = None
     model.tok2vec = tok2vec
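A minimal usage sketch of the changed entry point, assuming build_tagger_model still lives in spacy._ml (the file path is not shown in this diff, so the import is an assumption) and using an illustrative class count. With the new signature, token_vector_width is no longer a positional argument: it is taken from **cfg if supplied there, otherwise from the token_vector_width environment option (default 128).

# Hypothetical usage sketch; the module path and sizes are assumptions for illustration.
from spacy._ml import build_tagger_model

# token_vector_width is omitted here, so the env_opt fallback of 128 applies.
model = build_tagger_model(nr_class=17, pretrained_dims=0)

# The unflattened tok2vec sub-model stays exposed for sharing with other components.
tok2vec = model.tok2vec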