diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 6f23a08b5..645e5d9e6 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -416,7 +416,9 @@ cdef class Parser: free(scores) free(token_ids) - def update(self, docs_tokvecs, golds, drop=0., sgd=None): + def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): + if losses is not None and self.name not in losses: + losses[self.name] = 0. docs, tokvec_lists = docs_tokvecs tokvecs = self.model[0].ops.flatten(tokvec_lists) if isinstance(docs, Doc) and isinstance(golds, GoldParse): @@ -436,18 +438,20 @@ cdef class Parser: backprops = [] d_tokvecs = state2vec.ops.allocate(tokvecs.shape) cdef float loss = 0. - while len(todo) >= 3: + while len(todo) >= 2: states, golds = zip(*todo) token_ids = self.get_token_ids(states) vector, bp_vector = state2vec.begin_update(token_ids, drop=0.0) - mask = vec2scores.ops.get_dropout_mask(vector.shape, drop) - vector *= mask + if drop != 0: + mask = vec2scores.ops.get_dropout_mask(vector.shape, drop) + vector *= mask scores, bp_scores = vec2scores.begin_update(vector, drop=drop) d_scores = self.get_batch_loss(states, golds, scores) d_vector = bp_scores(d_scores, sgd=sgd) - d_vector *= mask + if drop != 0: + d_vector *= mask if isinstance(self.model[0].ops, CupyOps) \ and not isinstance(token_ids, state2vec.ops.xp.ndarray): @@ -461,10 +465,12 @@ cdef class Parser: backprops.append((token_ids, d_vector, bp_vector)) self.transition_batch(states, scores) todo = [st for st in todo if not st[0].is_final()] - if len(backprops) >= 50: + if len(backprops) >= 20: self._make_updates(d_tokvecs, backprops, sgd, cuda_stream) backprops = [] + if losses is not None: + losses[self.name] += (d_scores**2).sum() if backprops: self._make_updates(d_tokvecs, backprops, sgd, cuda_stream)