Fix defaults and args to build_tagger_model

This commit is contained in:
Matthew Honnibal 2017-09-17 05:46:36 -05:00
parent c003c561c3
commit 8f913a74ca
1 changed files with 2 additions and 3 deletions

View File

@ -467,9 +467,8 @@ def getitem(i):
return X[i], None return X[i], None
return layerize(getitem_fwd) return layerize(getitem_fwd)
def build_tagger_model(nr_class, token_vector_width, **cfg): def build_tagger_model(nr_class, token_vector_width, pretrained_dims=0, **cfg):
embed_size = util.env_opt('embed_size', 7500) embed_size = util.env_opt('embed_size', 4000)
pretrained_dims = cfg.get('pretrained_dims', 0)
with Model.define_operators({'>>': chain, '+': add}): with Model.define_operators({'>>': chain, '+': add}):
# Input: (doc, tensor) tuples # Input: (doc, tensor) tuples
private_tok2vec = Tok2Vec(token_vector_width, embed_size, private_tok2vec = Tok2Vec(token_vector_width, embed_size,