diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 8160f52e8..ca238774a 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -51,6 +51,7 @@ from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts from .._ml import Tok2Vec, doc2feats, rebatch from ..compat import json_dumps +from . import _beam_utils from . import _parse_features from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport fill_context @@ -504,6 +505,9 @@ cdef class Parser: losses[self.name] = 0. docs, tokvec_lists = docs_tokvecs tokvecs = self.model[0].ops.flatten(tokvec_lists) + my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) + tokvecs += self.model[0].ops.flatten(my_tokvecs) + if isinstance(docs, Doc) and isinstance(golds, GoldParse): docs = [docs] golds = [golds] @@ -557,8 +561,7 @@ cdef class Parser: self._make_updates(d_tokvecs, backprops, sgd, cuda_stream) d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs]) - if USE_FINE_TUNE: - bp_my_tokvecs(d_tokvecs, sgd=sgd) + bp_my_tokvecs(d_tokvecs, sgd=sgd) return d_tokvecs def update_beam(self, docs_tokvecs, golds, width=None, density=None, @@ -573,10 +576,9 @@ cdef class Parser: lengths = [len(d) for d in docs] assert min(lengths) >= 1 tokvecs = self.model[0].ops.flatten(tokvecs) - if USE_FINE_TUNE: - my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) - my_tokvecs = self.model[0].ops.flatten(my_tokvecs) - tokvecs += my_tokvecs + my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) + my_tokvecs = self.model[0].ops.flatten(my_tokvecs) + tokvecs += my_tokvecs states = self.moves.init_batch(docs) for gold in golds: @@ -607,8 +609,7 @@ cdef class Parser: d_tokvecs = self.model[0].ops.allocate(tokvecs.shape) self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream) d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, lengths) - if USE_FINE_TUNE: - bp_my_tokvecs(d_tokvecs, sgd=sgd) + bp_my_tokvecs(d_tokvecs, sgd=sgd) return d_tokvecs def _init_gold_batch(self, whole_docs, whole_golds):