mirror of https://github.com/explosion/spaCy.git
Fix tagger
This commit is contained in:
parent
da7650e84b
commit
3e3a309764
|
@ -681,7 +681,7 @@ class Tagger(Pipe):
|
|||
idx += 1
|
||||
correct = self.model.ops.xp.array(correct, dtype='i')
|
||||
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
||||
d_scores *= self.ops.asarray(known_labels)
|
||||
d_scores *= self.model.ops.asarray(known_labels)
|
||||
loss = (d_scores**2).sum()
|
||||
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
|
||||
return float(loss), d_scores
|
||||
|
|
Loading…
Reference in New Issue