Fix Tok2Vec

Matthew Honnibal 2017-09-23 02:58:54 +02:00
parent 386c1a5bd8
commit 4bd6a12b1f
1 changed file with 5 additions and 3 deletions


@@ -475,14 +475,16 @@ def getitem(i):
     return layerize(getitem_fwd)
 
 
-def build_tagger_model(nr_class, token_vector_width, pretrained_dims=0, **cfg):
+def build_tagger_model(nr_class, pretrained_dims=0, **cfg):
     embed_size = util.env_opt('embed_size', 4000)
+    if 'token_vector_width' not in cfg:
+        token_vector_width = util.env_opt('token_vector_width', 128)
     with Model.define_operators({'>>': chain, '+': add}):
         tok2vec = Tok2Vec(token_vector_width, embed_size,
                           pretrained_dims=pretrained_dims)
-        model = with_flatten(
+        model = (
             tok2vec
-            >> Softmax(nr_class, token_vector_width)
+            >> with_flatten(Softmax(nr_class, token_vector_width))
         )
     model.nI = None
     model.tok2vec = tok2vec
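
As far as the hunk shows, the change does two things: token_vector_width is no longer a positional argument and instead falls back to the token_vector_width environment option when it is not supplied via cfg, and with_flatten now wraps only the Softmax layer rather than the whole tok2vec >> Softmax pipeline, so Tok2Vec receives the per-document input directly while the per-token Softmax is applied once over a flattened array. Below is a minimal, self-contained sketch of the flatten/apply/unflatten pattern that a with_flatten-style wrapper provides; the helper and layer names are hypothetical and use plain numpy, not thinc's actual implementation.

    # Hypothetical illustration of the with_flatten pattern (not spaCy/thinc code).
    import numpy

    def with_flatten_sketch(layer):
        """Wrap a layer that maps a 2D array to a 2D array so it accepts a
        list of per-document arrays and returns a list with the same lengths."""
        def wrapped(docs):
            lengths = [len(doc) for doc in docs]
            flat = numpy.concatenate(docs)   # (sum(lengths), width)
            flat_out = layer(flat)           # apply once over all tokens
            # Split the flat output back into one array per document.
            out = []
            start = 0
            for length in lengths:
                out.append(flat_out[start:start + length])
                start += length
            return out
        return wrapped

    def toy_softmax(X):
        # Row-wise softmax standing in for the per-token Softmax layer.
        e = numpy.exp(X - X.max(axis=1, keepdims=True))
        return e / e.sum(axis=1, keepdims=True)

    docs = [numpy.random.rand(3, 2), numpy.random.rand(5, 2)]
    tagged = with_flatten_sketch(toy_softmax)(docs)
    print([t.shape for t in tagged])  # [(3, 2), (5, 2)]

Under this reading, wrapping only the Softmax keeps the flattening concern next to the one layer that needs a single 2D array, instead of forcing the whole pipeline through the wrapper.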