diff --git a/spacy/syntax/conll.pyx b/spacy/syntax/conll.pyx index 860adb1ad..0e0899c71 100644 --- a/spacy/syntax/conll.pyx +++ b/spacy/syntax/conll.pyx @@ -12,12 +12,25 @@ cdef class GoldParse: self.c_heads = self.mem.alloc(self.length, sizeof(int)) self.c_labels = self.mem.alloc(self.length, sizeof(int)) + @property + def n_non_punct(self): + return len([l for l in self.labels if l != 'P']) + + @property + def py_heads(self): + return [self.c_heads[i] for i in range(self.length)] + cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1: n = 0 for i in range(self.length): + if not score_punct and self.labels[i] == 'P': + continue n += (i + tokens[i].head) == self.c_heads[i] return n + def is_correct(self, i, head): + return head == self.c_heads[i] + @classmethod def from_conll(cls, unicode sent_str): ids = [] @@ -96,6 +109,10 @@ cdef class GoldParse: self.c_heads = self.mem.alloc(self.length, sizeof(int)) self.c_labels = self.mem.alloc(self.length, sizeof(int)) self.ids = [token.idx for token in tokens] + self.map_heads(label_ids) + return self.loss + + def map_heads(self, label_ids): mapped_heads = _map_indices_to_tokens(self.ids, self.heads) for i in range(self.length): if mapped_heads[i] is None: @@ -121,7 +138,6 @@ def _map_indices_to_tokens(ids, heads): return mapped - def _parse_line(line): pieces = line.split() if len(pieces) == 4: