spaCy/spacy/tests/doc/test_span_group.py

243 lines
7.1 KiB
Python
Raw Normal View History

Support more internal methods for SpanGroup (#10476) * Added new convenience cython functions to SpanGroup to avoid unnecessary allocation/deallocation of objects * Replaced sorting in has_overlap with C++ for efficiency. Also, added a test for has_overlap * Added a method to efficiently merge SpanGroups * Added __delitem__, __add__ and __iadd__. Also, allowed to pass span lists to merge function. Replaced extend() body with call to merge * Renamed merge to concat and added missing things to documentation * Added operator+ and operator += in the documentation * Added a test for Doc deallocation * Update spacy/tokens/span_group.pyx * Updated SpanGroup tests to use new span list comparison function rather than assert_span_list_equal, eliminating the need to have a separate assert_not_equal fnction * Fixed typos in SpanGroup documentation Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Minor changes requested by Sofie: rearranged import statements. Added new=3.2.1 tag to SpanGroup.__setitem__ documentation * SpanGroup: moved repetitive list index check/adjustment in a separate function * Turn off formatting that hurts readability spacy/tests/doc/test_span_group.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Remove formatting that hurts readability spacy/tests/doc/test_span_group.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Turn off formatting that hurts readability in spacy/tests/doc/test_span_group.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Support more internal methods for SpanGroup Add support for: * `__setitem__` * `__delitem__` * `__iadd__`: for `SpanGroup` or `Iterable[Span]` * `__add__`: for `SpanGroup` only Adapted from #9698 with the scope limited to the magic methods. 
* Use v3.3 as new version in docs * Add new tag to SpanGroup.copy in API docs * Remove duplicate import * Apply suggestions from code review Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Remaining suggestions and formatting Co-authored-by: nrodnova <nrodnova@hotmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Natalia Rodnova <4512370+nrodnova@users.noreply.github.com>
2022-04-01 07:56:26 +00:00
import pytest
from random import Random
from spacy.matcher import Matcher
from spacy.tokens import Span, SpanGroup
def _make_doc_with_spans(en_tokenizer):
    """Build the document shared by the fixtures in this module.

    The text "0 1 2 3 4 5 6" is tokenized and run through a `Matcher` holding
    three patterns keyed "4", "2" and "1", made of four, two and one empty
    token specs respectively. Every match becomes a `Span` labelled with the
    key of the pattern that produced it. The spans are shuffled with a fixed
    seed, so their order is scrambled but reproducible, and stored on the
    document as the span group "SPANS" with attrs `{"key": "value"}`.

    en_tokenizer: Tokenizer fixture; also supplies the vocab for the matcher.
    RETURNS: The document carrying the "SPANS" span group.
    """
    doc = en_tokenizer("0 1 2 3 4 5 6")
    matcher = Matcher(en_tokenizer.vocab, validate=True)

    # fmt: off
    matcher.add("4", [[{}, {}, {}, {}]])
    matcher.add("2", [[{}, {}, ]])
    matcher.add("1", [[{}, ]])
    # fmt: on

    strings = en_tokenizer.vocab.strings
    # Each match is a (match_id, start, end) triple; the id is mapped back to
    # its string form through the string store to serve as the span label.
    spans = [
        Span(doc, start, end, strings[match_id])
        for match_id, start, end in matcher(doc)
    ]
    Random(42).shuffle(spans)
    doc.spans["SPANS"] = SpanGroup(
        doc, name="SPANS", attrs={"key": "value"}, spans=spans
    )
    return doc


@pytest.fixture
def doc(en_tokenizer):
    """Document with a shuffled "SPANS" span group attached."""
    return _make_doc_with_spans(en_tokenizer)


@pytest.fixture
def other_doc(en_tokenizer):
    """A second, separately built document with the same content as `doc`.

    Tests use it as the source of spans and span groups that belong to a
    different document.
    """
    return _make_doc_with_spans(en_tokenizer)


@pytest.fixture
def span_group(en_tokenizer):
    """Build a document with a "SPANS" span group and then drop it.

    NOTE(review): this fixture has never returned anything, so tests that
    request it receive None. That is kept as is because the test using it
    expects an AttributeError; confirm whether returning the span group (and
    adjusting that test) was the real intent.
    """
    _make_doc_with_spans(en_tokenizer)
def test_span_group_copy(doc):
    """A copied span group starts out equal in content and then diverges."""
    original = doc.spans["SPANS"]
    duplicate = original.copy()

    # Straight after copying: compares unequal as an object, but the name,
    # attrs and spans all match the source group.
    assert duplicate != original
    assert duplicate.name == original.name
    assert duplicate.attrs == original.attrs
    assert len(duplicate) == len(original)
    assert list(original) == list(duplicate)

    # Changing the copy must not leak back into the source group.
    duplicate.name = "new_name"
    duplicate.attrs["key"] = "new_value"
    duplicate.append(Span(doc, 0, 6, "LABEL"))

    assert duplicate.name != original.name
    assert duplicate.attrs != original.attrs
    assert original.attrs["key"] == "value"
    assert list(original) != list(duplicate)
def test_span_group_set_item(doc, other_doc):
    """Item assignment stores the span; bad indices and foreign docs fail."""
    group = doc.spans["SPANS"]
    position = 5

    # Edits made to a span fetched from the group are not visible in the
    # group until the span is assigned back.
    edited = group[position]
    edited.label_ = "NEW LABEL"
    edited.kb_id = doc.vocab.strings["KB_ID"]
    assert group[position].label != edited.label
    assert group[position].kb_id != edited.kb_id

    group[position] = edited
    for field in ("start", "end", "label", "kb_id"):
        assert getattr(group[position], field) == getattr(edited, field)
    assert group[position] == edited

    # Indices outside the group are rejected in both directions.
    for bad_index in (-100, 100):
        with pytest.raises(IndexError):
            group[bad_index] = edited

    # A span that belongs to another document cannot be stored.
    foreign = Span(other_doc, 0, 2)
    with pytest.raises(ValueError):
        group[position] = foreign
def test_span_group_has_overlap(doc):
    """The fixture's "SPANS" group is expected to report overlapping spans."""
    assert doc.spans["SPANS"].has_overlap
def test_span_group_concat(doc, other_doc):
    """`_concat` joins two groups, either as a new group or in place."""
    base = doc.spans["SPANS"]
    extra = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=[doc[0:5], doc[0:6]],
    )
    # On a key clash the left-hand group's value is the one expected to win.
    merged_attrs = {"key": "value", "new_key": "new_value"}

    # Default mode: a new group is produced.
    merged = base._concat(extra)
    assert merged.name == base.name
    assert merged.attrs == merged_attrs
    expected_spans = [*base, *extra]
    assert list(merged) == expected_spans

    # Inplace: the left-hand group itself is extended and handed back.
    merged = base._concat(extra, inplace=True)
    assert merged == base
    assert merged.name == base.name
    assert merged.attrs == merged_attrs
    assert list(merged) == expected_spans

    # Groups that live on different documents cannot be joined.
    foreign = other_doc.spans["SPANS"]
    with pytest.raises(ValueError):
        base._concat(foreign)
def test_span_doc_delitem(doc):
    """Deleting an item shrinks the group and shifts its successor down."""
    group = doc.spans["SPANS"]
    size_before = len(group)
    position = 5
    removed = group[position]
    successor = group[position + 1]

    del group[position]

    assert len(group) == size_before - 1
    # The slot now holds what used to be the next span.
    assert group[position] != removed
    assert group[position] == successor

    # Indices outside the group are rejected in both directions.
    for bad_index in (-100, 100):
        with pytest.raises(IndexError):
            del group[bad_index]
def test_span_group_add(doc):
    """`group + group` gives the same result as the non-inplace `_concat`."""
    base = doc.spans["SPANS"]
    extra = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=[doc[0:5], doc[0:6]],
    )

    expected = base._concat(extra)
    combined = base + extra

    assert len(combined) == len(expected)
    assert combined.attrs == {"key": "value", "new_key": "new_value"}
    assert list(combined) == list(expected)
def test_span_group_iadd(doc):
    """`+=` accepts either another span group or a plain list of spans."""
    extra_spans = [doc[0:5], doc[0:6]]
    target = doc.spans["SPANS"].copy()
    extra = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=extra_spans,
    )
    expected = target._concat(extra)

    # Right-hand side is a span group: spans are appended and attrs merged.
    target += extra
    assert len(target) == len(expected)
    assert target.attrs == {"key": "value", "new_key": "new_value"}
    assert list(target) == list(expected)

    # Right-hand side is a list of spans: same spans, attrs left untouched.
    target = doc.spans["SPANS"].copy()
    target += extra_spans
    assert len(target) == len(expected)
    assert target.attrs == {"key": "value"}
    assert list(target) == list(expected)
def test_span_group_extend(doc):
    """`extend` accepts either another span group or a plain list of spans."""
    extra_spans = [doc[0:5], doc[0:6]]
    target = doc.spans["SPANS"].copy()
    extra = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=extra_spans,
    )
    expected = target._concat(extra)

    # Extending with a span group: spans are appended and attrs merged.
    target.extend(extra)
    assert len(target) == len(expected)
    assert target.attrs == {"key": "value", "new_key": "new_value"}
    assert list(target) == list(expected)

    # Extending the document's own group with a list: attrs left untouched.
    target = doc.spans["SPANS"]
    target.extend(extra_spans)
    assert len(target) == len(expected)
    assert target.attrs == {"key": "value"}
    assert list(target) == list(expected)
def test_span_group_dealloc(span_group):
    """Reading `.doc` from the `span_group` fixture value must fail.

    NOTE(review): the `span_group` fixture in this file has no return
    statement, so the value received here is None and the AttributeError
    comes from None rather than from a SpanGroup whose document is gone.
    Confirm whether that is the intended check.
    """
    expected_failure = pytest.raises(AttributeError)
    with expected_failure:
        print(span_group.doc)