From 67d6e53a6959ec0e3c7c30fff95462c8bba54102 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 30 Jan 2015 14:08:56 +1100 Subject: [PATCH] * Ensure parser and tagger function correctly when training from missing values, indicated by -1 --- spacy/en/pos.pyx | 10 +++++++--- spacy/syntax/parser.pyx | 9 +++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/spacy/en/pos.pyx b/spacy/en/pos.pyx index 1e19b9b82..d8d1685b2 100644 --- a/spacy/en/pos.pyx +++ b/spacy/en/pos.pyx @@ -255,19 +255,23 @@ cdef class EnPosTagger: tokens._tag_strings = self.tag_names tokens.is_tagged = True - def train(self, Tokens tokens, object golds): + def train(self, Tokens tokens, object gold_tag_strs): cdef int i + cdef int loss cdef atom_t[N_CONTEXT_FIELDS] context cdef const weight_t* scores + golds = [self.tag_names.index(g) if g is not None else -1 + for g in gold_tag_strs] correct = 0 for i in range(tokens.length): fill_context(context, i, tokens.data) scores = self.model.score(context) guess = arg_max(scores, self.model.n_classes) - self.model.update(context, guess, golds[i], guess != golds[i]) + loss = guess != golds[i] if golds[i] != -1 else 0 + self.model.update(context, guess, golds[i], loss) tokens.data[i].tag = guess self.set_morph(i, tokens.data) - correct += guess == golds[i] + correct += loss == 0 return correct cdef int set_morph(self, const int i, TokenC* tokens) except -1: diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 61324f69c..4144e93cd 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -102,8 +102,12 @@ cdef class GreedyParser: cdef int* labels_array = mem.alloc(tokens.length, sizeof(int)) cdef int i for i in range(tokens.length): - heads_array[i] = gold_heads[i] - labels_array[i] = self.moves.label_ids[gold_labels[i]] + if gold_heads[i] is None: + heads_array[i] = -1 + labels_array[i] = -1 + else: + heads_array[i] = gold_heads[i] + labels_array[i] = self.moves.label_ids[gold_labels[i]] py_words = [t.orth_ for t in tokens] py_moves = ['S', 'D', 'L', 'R', 'BS', 'BR'] @@ -123,6 +127,7 @@ cdef class GreedyParser: self.moves.transition(state, &guess) cdef int n_corr = 0 for i in range(tokens.length): + if gold_heads[i] != -1: n_corr += (i + state.sent[i].head) == gold_heads[i] if force_gold and n_corr != tokens.length: print py_words