From 87d6551d1920a6c50816ec0b981b98ec76839468 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 31 May 2015 01:11:56 +0200 Subject: [PATCH] * Allow gold parse to cut non-projective arcs --- spacy/gold.pyx | 48 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 52416c06b..244d7afeb 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -163,7 +163,7 @@ def _consume_ent(tags): cdef class GoldParse: - def __init__(self, tokens, annot_tuples, brackets=tuple()): + def __init__(self, tokens, annot_tuples, brackets=tuple(), make_projective=False): self.mem = Pool() self.loss = 0 self.length = len(tokens) @@ -196,6 +196,24 @@ cdef class GoldParse: self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]] self.labels[i] = annot_tuples[4][gold_i] self.ner[i] = annot_tuples[5][gold_i] + + # If we have any non-projective arcs, i.e. crossing brackets, consider + # the heads for those words missing in the gold-standard. + # This way, we can train from these sentences + cdef int w1, w2, h1, h2 + if make_projective: + heads = list(self.heads) + for w1 in range(self.length): + if heads[w1] is not None: + h1 = heads[w1] + for w2 in range(w1+1, self.length): + if heads[w2] is not None: + h2 = heads[w2] + if _arcs_cross(w1, h1, w2, h2): + self.heads[w1] = None + self.labels[w1] = '' + self.heads[w2] = None + self.labels[w2] = '' self.brackets = {} for (gold_start, gold_end, label_str) in brackets: @@ -210,16 +228,24 @@ cdef class GoldParse: @property def is_projective(self): - heads = [head for (id_, word, tag, head, dep, ner) in self.orig_annot] - deps = sorted([sorted(arc) for arc in enumerate(heads)]) - for w1, h1 in deps: - for w2, h2 in deps: - if w1 < w2 < h1 < h2: - return False - elif w1 < w2 == h2 < h1: - return False - else: - return True + heads = list(self.heads) + for w1 in range(self.length): + if heads[w1] is not None: + h1 = heads[w1] + for w2 in range(self.length): + if heads[w2] is not None and _arcs_cross(w1, h1, w2, heads[w2]): + return False + return True + + +cdef int _arcs_cross(int w1, int h1, int w2, int h2) except -1: + if w1 > h1: + w1, h1 = h1, w1 + if w2 > h2: + w2, h2 = h2, w2 + if w1 > w2: + w1, h1, w2, h2 = w2, h2, w1, h1 + return w1 < w2 < h1 < h2 or w1 < w2 == h2 < h1 def is_punct_label(label):