mirror of https://github.com/explosion/spaCy.git
* Ensure parser and tagger function correctly when training from missing values, indicated by -1
This commit is contained in:
parent
4ff180db74
commit
67d6e53a69
|
@ -255,19 +255,23 @@ cdef class EnPosTagger:
|
||||||
tokens._tag_strings = self.tag_names
|
tokens._tag_strings = self.tag_names
|
||||||
tokens.is_tagged = True
|
tokens.is_tagged = True
|
||||||
|
|
||||||
def train(self, Tokens tokens, object golds):
|
def train(self, Tokens tokens, object gold_tag_strs):
|
||||||
cdef int i
|
cdef int i
|
||||||
|
cdef int loss
|
||||||
cdef atom_t[N_CONTEXT_FIELDS] context
|
cdef atom_t[N_CONTEXT_FIELDS] context
|
||||||
cdef const weight_t* scores
|
cdef const weight_t* scores
|
||||||
|
golds = [self.tag_names.index(g) if g is not None else -1
|
||||||
|
for g in gold_tag_strs]
|
||||||
correct = 0
|
correct = 0
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
fill_context(context, i, tokens.data)
|
fill_context(context, i, tokens.data)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = arg_max(scores, self.model.n_classes)
|
guess = arg_max(scores, self.model.n_classes)
|
||||||
self.model.update(context, guess, golds[i], guess != golds[i])
|
loss = guess != golds[i] if golds[i] != -1 else 0
|
||||||
|
self.model.update(context, guess, golds[i], loss)
|
||||||
tokens.data[i].tag = guess
|
tokens.data[i].tag = guess
|
||||||
self.set_morph(i, tokens.data)
|
self.set_morph(i, tokens.data)
|
||||||
correct += guess == golds[i]
|
correct += loss == 0
|
||||||
return correct
|
return correct
|
||||||
|
|
||||||
cdef int set_morph(self, const int i, TokenC* tokens) except -1:
|
cdef int set_morph(self, const int i, TokenC* tokens) except -1:
|
||||||
|
|
|
@ -102,6 +102,10 @@ cdef class GreedyParser:
|
||||||
cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int))
|
cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int))
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
|
if gold_heads[i] is None:
|
||||||
|
heads_array[i] = -1
|
||||||
|
labels_array[i] = -1
|
||||||
|
else:
|
||||||
heads_array[i] = gold_heads[i]
|
heads_array[i] = gold_heads[i]
|
||||||
labels_array[i] = self.moves.label_ids[gold_labels[i]]
|
labels_array[i] = self.moves.label_ids[gold_labels[i]]
|
||||||
|
|
||||||
|
@ -123,6 +127,7 @@ cdef class GreedyParser:
|
||||||
self.moves.transition(state, &guess)
|
self.moves.transition(state, &guess)
|
||||||
cdef int n_corr = 0
|
cdef int n_corr = 0
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
|
if gold_heads[i] != -1:
|
||||||
n_corr += (i + state.sent[i].head) == gold_heads[i]
|
n_corr += (i + state.sent[i].head) == gold_heads[i]
|
||||||
if force_gold and n_corr != tokens.length:
|
if force_gold and n_corr != tokens.length:
|
||||||
print py_words
|
print py_words
|
||||||
|
|
Loading…
Reference in New Issue