fix sent_starts

This commit is contained in:
svlandeg 2021-01-13 13:47:25 +01:00
parent 232e953b14
commit 86a4e316b8
5 changed files with 37 additions and 2 deletions

View File

@ -261,14 +261,18 @@ def test_missing_head_dep(en_vocab):
doc = Doc(en_vocab, words=words, heads=heads, deps=deps)
pred_has_heads = [t.has_head() for t in doc]
pred_deps = [t.dep_ for t in doc]
pred_sent_starts = [t.is_sent_start for t in doc]
assert pred_has_heads == [True, True, True, True, True, False]
assert pred_deps == ["nsubj", "ROOT", "dobj", "cc", "conj", MISSING_DEP_]
assert pred_sent_starts == [True, False, False, False, False, False]
example = Example.from_dict(doc, {"heads": heads, "deps": deps})
ref_heads = [t.head.i for t in example.reference]
ref_deps = [t.dep_ for t in example.reference]
ref_has_heads = [t.has_head() for t in example.reference]
ref_sent_starts = [t.is_sent_start for t in example.reference]
assert ref_deps == ["nsubj", "ROOT", "dobj", "cc", "conj", MISSING_DEP_]
assert ref_has_heads == [True, True, True, True, True, False]
assert ref_sent_starts == [True, False, False, False, False, False]
aligned_heads, aligned_deps = example.get_aligned_parse(projectivize=True)
assert aligned_heads[5] == ref_heads[5]
assert aligned_deps[5] == MISSING_DEP_

View File

@ -282,3 +282,24 @@ def test_Example_missing_deps():
# when providing deps, the head information is actually used
example_2 = Example.from_dict(predicted, annots_head_dep)
assert [t.head.i for t in example_2.reference] == heads
def test_Example_missing_heads():
    """Check that a token annotated without a head stays head-less in the
    reference doc and does not show up as a sentence start."""
    vocab = Vocab()
    tokens = ["I", "like", "London", "and", "Berlin", "."]
    gold_deps = ["nsubj", "ROOT", "dobj", None, "conj", "punct"]
    gold_heads = [1, 1, 1, None, 2, 1]
    gold = {"words": tokens, "heads": gold_heads, "deps": gold_deps}
    example = Example.from_dict(Doc(vocab, words=tokens), gold)
    reference = example.reference
    ref_heads = [token.head.i for token in reference]
    # Only positions with a gold head are compared; index 3 ("and") has none.
    for idx in (0, 1, 2, 4, 5):
        assert ref_heads[idx] == gold_heads[idx]
    expected_has_head = [True, True, True, False, True, True]
    assert [token.has_head() for token in reference] == expected_has_head
    # The missing head must not introduce an artificial new sentence start
    expected_sent_starts = [True, False, False, False, False, False]
    assert example.get_aligned_sent_starts() == expected_sent_starts

View File

@ -1540,7 +1540,7 @@ cdef int set_children_from_heads(TokenC* tokens, int start, int end) except -1:
for i in range(start, end):
tokens[i].sent_start = -1
for i in range(start, end):
if tokens[i].head == 0:
if tokens[i].head == 0 and not Token.missing_head(&tokens[i]):
tokens[tokens[i].l_edge].sent_start = 1

View File

@ -94,3 +94,10 @@ cdef class Token:
token.ent_kb_id = value
elif feat_name == SENT_START:
token.sent_start = value
@staticmethod
cdef inline int missing_head(const TokenC* token) nogil:
    # Report whether the token carries no head annotation: returns 1 when
    # its `dep` field is 0 (no dependency label set), and 0 otherwise.
    return 1 if token.dep == 0 else 0

View File

@ -184,7 +184,10 @@ cdef class Example:
heads = [token.head.i for token in self.y]
deps = [token.dep_ for token in self.y]
if projectivize:
heads, deps = nonproj.projectivize(heads, deps)
proj_heads, proj_deps = nonproj.projectivize(heads, deps)
# ensure that data that was previously missing, remains missing
heads = [h if has_heads[i] else heads[i] for i, h in enumerate(proj_heads)]
deps = [d if deps[i] != MISSING_DEP_ else MISSING_DEP_ for i, d in enumerate(proj_deps)]
for cand_i in range(self.x.length):
if cand_to_gold.lengths[cand_i] == 1:
gold_i = cand_to_gold[cand_i].dataXd[0, 0]