one more losses fix

This commit is contained in:
svlandeg 2020-10-14 15:11:34 +02:00
parent 478a14a619
commit 44e14ccae8
1 changed files with 5 additions and 4 deletions

View File

@ -227,10 +227,13 @@ class Tagger(TrainablePipe):
DOCS: https://nightly.spacy.io/api/tagger#rehearse
"""
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "Tagger.rehearse")
docs = [eg.predicted for eg in examples]
if self._rehearsal_model is None:
return
return losses
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
return losses
@ -240,9 +243,7 @@ class Tagger(TrainablePipe):
gradient = guesses - target
backprop(gradient)
self.finish_update(sgd)
if losses is not None:
losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient**2).sum()
losses[self.name] += (gradient**2).sum()
return losses
def get_loss(self, examples, scores):