Fix MultitaskObjective

This commit is contained in:
Matthew Honnibal 2018-01-21 19:21:34 +01:00
parent 82135d85b7
commit 61a051f2c0
1 changed files with 15 additions and 5 deletions

View File

@ -652,10 +652,7 @@ class MultitaskObjective(Tagger):
self.labels[label] = len(self.labels) self.labels[label] = len(self.labels)
if self.model is True: if self.model is True:
token_vector_width = util.env_opt('token_vector_width') token_vector_width = util.env_opt('token_vector_width')
self.model = chain( self.model = self.Model(len(self.labels), tok2vec=tok2vec)
tok2vec,
Softmax(len(self.labels), token_vector_width)
)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
@ -663,7 +660,20 @@ class MultitaskObjective(Tagger):
@classmethod @classmethod
def Model(cls, n_tags, tok2vec=None, **cfg): def Model(cls, n_tags, tok2vec=None, **cfg):
return build_tagger_model(n_tags, tok2vec=tok2vec, **cfg) token_vector_width = util.env_opt('token_vector_width', 128)
softmax = Softmax(n_tags, token_vector_width)
model = chain(
tok2vec,
softmax
)
model.tok2vec = tok2vec
model.softmax = softmax
return model
def predict(self, docs):
tokvecs = self.model.tok2vec(docs)
scores = self.model.softmax(tokvecs)
return tokvecs, scores
def get_loss(self, docs, golds, scores): def get_loss(self, docs, golds, scores):
cdef int idx = 0 cdef int idx = 0