diff --git a/spacy/syntax/conll.pxd b/spacy/syntax/conll.pxd index 508c575c0..6fc27b151 100644 --- a/spacy/syntax/conll.pxd +++ b/spacy/syntax/conll.pxd @@ -18,10 +18,12 @@ cdef class GoldParse: cdef readonly list ents cdef readonly dict brackets + cdef readonly list cand_to_gold + cdef readonly list gold_to_cand + cdef readonly list orig_annot + cdef int* c_tags cdef int* c_heads cdef int* c_labels cdef int** c_brackets cdef Transition* c_ner - - cdef int heads_correct(self, TokenC* tokens, bint score_punct=?) except -1 diff --git a/spacy/syntax/conll.pyx b/spacy/syntax/conll.pyx index a84a73d5e..974f8c65a 100644 --- a/spacy/syntax/conll.pyx +++ b/spacy/syntax/conll.pyx @@ -162,18 +162,20 @@ cdef class GoldParse: self.labels = [''] * len(tokens) self.ner = ['-'] * len(tokens) - cand_to_gold = align([t.orth_ for t in tokens], annot_tuples[1]) - gold_to_cand = align(annot_tuples[1], [t.orth_ for t in tokens]) + self.cand_to_gold = align([t.orth_ for t in tokens], annot_tuples[1]) + self.gold_to_cand = align(annot_tuples[1], [t.orth_ for t in tokens]) + + self.orig_annot = zip(*annot_tuples) self.ents = [] - for i, gold_i in enumerate(cand_to_gold): + for i, gold_i in enumerate(self.cand_to_gold): if gold_i is None: # TODO: What do we do for missing values again? pass else: self.tags[i] = annot_tuples[2][gold_i] - self.heads[i] = gold_to_cand[annot_tuples[3][gold_i]] + self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]] self.labels[i] = annot_tuples[4][gold_i] # TODO: Declare NER information MISSING if tokenization incorrect for start, end, label in self.ents: @@ -187,8 +189,8 @@ cdef class GoldParse: self.brackets = {} for (gold_start, gold_end, label_str) in brackets: - start = gold_to_cand[gold_start] - end = gold_to_cand[gold_end] + start = self.gold_to_cand[gold_start] + end = self.gold_to_cand[gold_end] if start is not None and end is not None: self.brackets.setdefault(start, {}).setdefault(end, set()) self.brackets[end][start].add(label) @@ -196,33 +198,6 @@ cdef class GoldParse: def __len__(self): return self.length - @property - def n_non_punct(self): - return len([l for l in self.labels if l not in ('P', 'punct')]) - - cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1: - n = 0 - for i in range(self.length): - if not score_punct and self.labels_[i] not in ('P', 'punct'): - continue - if self.heads[i] == -1: - continue - n += (i + tokens[i].head) == self.heads[i] - return n - - def is_correct(self, i, head): - return head == self.c_heads[i] - def is_punct_label(label): return label == 'P' or label.lower() == 'punct' - - -def _map_indices_to_tokens(ids, heads): - mapped = [] - for head in heads: - if head not in ids: - mapped.append(None) - else: - mapped.append(ids.index(head)) - return mapped