From 61a051f2c039b5a44ea8397b51917350012a28ed Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sun, 21 Jan 2018 19:21:34 +0100
Subject: [PATCH] Fix MultitaskObjective

---
 spacy/pipeline.pyx | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx
index 3890e2422..11e7df3d1 100644
--- a/spacy/pipeline.pyx
+++ b/spacy/pipeline.pyx
@@ -652,10 +652,7 @@ class MultitaskObjective(Tagger):
                         self.labels[label] = len(self.labels)
         if self.model is True:
             token_vector_width = util.env_opt('token_vector_width')
-            self.model = chain(
-                tok2vec,
-                Softmax(len(self.labels), token_vector_width)
-            )
+            self.model = self.Model(len(self.labels), tok2vec=tok2vec)
         link_vectors_to_models(self.vocab)
         if sgd is None:
             sgd = self.create_optimizer()
@@ -663,7 +660,20 @@ class MultitaskObjective(Tagger):
 
     @classmethod
     def Model(cls, n_tags, tok2vec=None, **cfg):
-        return build_tagger_model(n_tags, tok2vec=tok2vec, **cfg)
+        token_vector_width = util.env_opt('token_vector_width', 128)
+        softmax = Softmax(n_tags, token_vector_width)
+        model = chain(
+            tok2vec,
+            softmax
+        )
+        model.tok2vec = tok2vec
+        model.softmax = softmax
+        return model
+
+    def predict(self, docs):
+        tokvecs = self.model.tok2vec(docs)
+        scores = self.model.softmax(tokvecs)
+        return tokvecs, scores
 
     def get_loss(self, docs, golds, scores):
         cdef int idx = 0