Normalize over all actions in parser, not just valid ones

This commit is contained in:
Matthew Honnibal 2019-03-15 15:22:16 +01:00
parent b94b2b1168
commit 693c8934e8
1 changed files with 5 additions and 8 deletions

View File

@ -156,7 +156,7 @@ cdef void cpu_log_loss(float* d_scores,
"""Do multi-label log loss""" """Do multi-label log loss"""
cdef double max_, gmax, Z, gZ cdef double max_, gmax, Z, gZ
best = arg_max_if_gold(scores, costs, is_valid, O) best = arg_max_if_gold(scores, costs, is_valid, O)
guess = arg_max_if_valid(scores, is_valid, O) guess = Vec.arg_max(scores, O)
if best == -1 or guess == -1: if best == -1 or guess == -1:
# These shouldn't happen, but if they do, we want to make sure we don't # These shouldn't happen, but if they do, we want to make sure we don't
# cause an OOB access. # cause an OOB access.
@ -166,14 +166,11 @@ cdef void cpu_log_loss(float* d_scores,
max_ = scores[guess] max_ = scores[guess]
gmax = scores[best] gmax = scores[best]
for i in range(O): for i in range(O):
if is_valid[i]:
Z += exp(scores[i] - max_) Z += exp(scores[i] - max_)
if costs[i] <= costs[best]: if costs[i] <= costs[best]:
gZ += exp(scores[i] - gmax) gZ += exp(scores[i] - gmax)
for i in range(O): for i in range(O):
if not is_valid[i]: if costs[i] <= costs[best]:
d_scores[i] = 0.
elif costs[i] <= costs[best]:
d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ) d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
else: else:
d_scores[i] = exp(scores[i]-max_) / Z d_scores[i] = exp(scores[i]-max_) / Z