Fix start/end chars for empty and out-of-bounds spans (#8816)

This commit is contained in:
Adriane Boyd 2021-08-02 19:07:19 +02:00 committed by GitHub
parent 9ad3b8cf8d
commit fbbbda1954
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 2 deletions

View File

@ -357,6 +357,9 @@ def test_span_eq_hash(doc, doc_not_parsed):
assert hash(doc[0:2]) != hash(doc[1:3]) assert hash(doc[0:2]) != hash(doc[1:3])
assert hash(doc[0:2]) != hash(doc_not_parsed[0:2]) assert hash(doc[0:2]) != hash(doc_not_parsed[0:2])
# check that an out-of-bounds is not equivalent to the span of the full doc
assert doc[0 : len(doc)] != doc[len(doc) : len(doc) + 1]
def test_span_boundaries(doc): def test_span_boundaries(doc):
start = 1 start = 1
@ -369,6 +372,33 @@ def test_span_boundaries(doc):
with pytest.raises(IndexError): with pytest.raises(IndexError):
span[5] span[5]
empty_span_0 = doc[0:0]
assert empty_span_0.text == ""
assert empty_span_0.start == 0
assert empty_span_0.end == 0
assert empty_span_0.start_char == 0
assert empty_span_0.end_char == 0
empty_span_1 = doc[1:1]
assert empty_span_1.text == ""
assert empty_span_1.start == 1
assert empty_span_1.end == 1
assert empty_span_1.start_char == empty_span_1.end_char
oob_span_start = doc[-len(doc) - 1 : -len(doc) - 10]
assert oob_span_start.text == ""
assert oob_span_start.start == 0
assert oob_span_start.end == 0
assert oob_span_start.start_char == 0
assert oob_span_start.end_char == 0
oob_span_end = doc[len(doc) + 1 : len(doc) + 10]
assert oob_span_end.text == ""
assert oob_span_end.start == len(doc)
assert oob_span_end.end == len(doc)
assert oob_span_end.start_char == len(doc.text)
assert oob_span_end.end_char == len(doc.text)
def test_span_lemma(doc): def test_span_lemma(doc):
# span lemmas should have the same number of spaces as the span # span lemmas should have the same number of spaces as the span

View File

@ -105,13 +105,18 @@ cdef class Span:
if label not in doc.vocab.strings: if label not in doc.vocab.strings:
raise ValueError(Errors.E084.format(label=label)) raise ValueError(Errors.E084.format(label=label))
start_char = doc[start].idx if start < doc.length else len(doc.text)
if start == end:
end_char = start_char
else:
end_char = doc[end - 1].idx + len(doc[end - 1])
self.c = SpanC( self.c = SpanC(
label=label, label=label,
kb_id=kb_id, kb_id=kb_id,
start=start, start=start,
end=end, end=end,
start_char=doc[start].idx if start < doc.length else 0, start_char=start_char,
end_char=doc[end - 1].idx + len(doc[end - 1]) if end >= 1 else 0, end_char=end_char,
) )
self._vector = vector self._vector = vector
self._vector_norm = vector_norm self._vector_norm = vector_norm