mirror of https://github.com/explosion/spaCy.git
* Allow gold parse to cut non-projective arcs
This commit is contained in:
parent
d512d20d81
commit
87d6551d19
|
@ -163,7 +163,7 @@ def _consume_ent(tags):
|
||||||
|
|
||||||
|
|
||||||
cdef class GoldParse:
|
cdef class GoldParse:
|
||||||
def __init__(self, tokens, annot_tuples, brackets=tuple()):
|
def __init__(self, tokens, annot_tuples, brackets=tuple(), make_projective=False):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.loss = 0
|
self.loss = 0
|
||||||
self.length = len(tokens)
|
self.length = len(tokens)
|
||||||
|
@ -196,6 +196,24 @@ cdef class GoldParse:
|
||||||
self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]]
|
self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]]
|
||||||
self.labels[i] = annot_tuples[4][gold_i]
|
self.labels[i] = annot_tuples[4][gold_i]
|
||||||
self.ner[i] = annot_tuples[5][gold_i]
|
self.ner[i] = annot_tuples[5][gold_i]
|
||||||
|
|
||||||
|
# If we have any non-projective arcs, i.e. crossing brackets, consider
|
||||||
|
# the heads for those words missing in the gold-standard.
|
||||||
|
# This way, we can train from these sentences
|
||||||
|
cdef int w1, w2, h1, h2
|
||||||
|
if make_projective:
|
||||||
|
heads = list(self.heads)
|
||||||
|
for w1 in range(self.length):
|
||||||
|
if heads[w1] is not None:
|
||||||
|
h1 = heads[w1]
|
||||||
|
for w2 in range(w1+1, self.length):
|
||||||
|
if heads[w2] is not None:
|
||||||
|
h2 = heads[w2]
|
||||||
|
if _arcs_cross(w1, h1, w2, h2):
|
||||||
|
self.heads[w1] = None
|
||||||
|
self.labels[w1] = ''
|
||||||
|
self.heads[w2] = None
|
||||||
|
self.labels[w2] = ''
|
||||||
|
|
||||||
self.brackets = {}
|
self.brackets = {}
|
||||||
for (gold_start, gold_end, label_str) in brackets:
|
for (gold_start, gold_end, label_str) in brackets:
|
||||||
|
@ -210,16 +228,24 @@ cdef class GoldParse:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_projective(self):
|
def is_projective(self):
|
||||||
heads = [head for (id_, word, tag, head, dep, ner) in self.orig_annot]
|
heads = list(self.heads)
|
||||||
deps = sorted([sorted(arc) for arc in enumerate(heads)])
|
for w1 in range(self.length):
|
||||||
for w1, h1 in deps:
|
if heads[w1] is not None:
|
||||||
for w2, h2 in deps:
|
h1 = heads[w1]
|
||||||
if w1 < w2 < h1 < h2:
|
for w2 in range(self.length):
|
||||||
return False
|
if heads[w2] is not None and _arcs_cross(w1, h1, w2, heads[w2]):
|
||||||
elif w1 < w2 == h2 < h1:
|
return False
|
||||||
return False
|
return True
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
cdef int _arcs_cross(int w1, int h1, int w2, int h2) except -1:
    """Report whether two dependency arcs conflict with each other.

    Each arc is given as a (word, head) pair of token indices; the direction
    of an arc is irrelevant, only the span it covers matters.

    Returns a non-zero value when either
      * the two spans partially overlap (one arc starts strictly inside the
        other and ends strictly outside it), or
      * the later-starting arc is a self-loop (word == head) that lies
        strictly inside the span of the other arc.
    Returns 0 otherwise.
    """
    cdef int start1, end1, start2, end2

    # Orient each arc left-to-right so that start <= end.
    if w1 > h1:
        start1, end1 = h1, w1
    else:
        start1, end1 = w1, h1
    if w2 > h2:
        start2, end2 = h2, w2
    else:
        start2, end2 = w2, h2

    # Make the first arc the one that begins no later than the second.
    if start1 > start2:
        start1, end1, start2, end2 = start2, end2, start1, end1

    # Partial overlap: the second arc opens inside the first and closes
    # after it.
    if start1 < start2 < end1 < end2:
        return True
    # Self-loop strictly enclosed by the first arc.
    return start1 < start2 == end2 < end1
|
||||||
|
|
||||||
|
|
||||||
def is_punct_label(label):
|
def is_punct_label(label):
|
||||||
|
|
Loading…
Reference in New Issue