mirror of https://github.com/explosion/spaCy.git
243 lines
7.1 KiB
Python
243 lines
7.1 KiB
Python
|
import pytest
|
||
|
from random import Random
|
||
|
from spacy.matcher import Matcher
|
||
|
from spacy.tokens import Span, SpanGroup
|
||
|
|
||
|
|
||
|
def _build_doc_with_span_group(en_tokenizer):
    """Build the shared test document and attach its "SPANS" span group.

    The text "0 1 2 3 4 5 6" is tokenized and run through a Matcher holding
    three rules named "4", "2" and "1", whose patterns are 4, 2 and 1 empty
    token specs long. Every match becomes a Span labelled with the rule name
    (looked up through ``vocab.strings``). The spans are shuffled with a
    fixed seed so their order is arbitrary but reproducible, then stored on
    the doc as ``doc.spans["SPANS"]`` with ``attrs={"key": "value"}``.

    en_tokenizer: The tokenizer fixture used to create the doc and vocab.
    RETURNS: The doc carrying the "SPANS" group.
    """
    doc = en_tokenizer("0 1 2 3 4 5 6")
    matcher = Matcher(en_tokenizer.vocab, validate=True)

    # fmt: off
    matcher.add("4", [[{}, {}, {}, {}]])
    matcher.add("2", [[{}, {}, ]])
    matcher.add("1", [[{}, ]])
    # fmt: on

    strings = en_tokenizer.vocab.strings
    spans = [
        Span(doc, start, end, strings[match_id])
        for match_id, start, end in matcher(doc)
    ]
    # Fixed seed: the order is scrambled, but identically on every run.
    Random(42).shuffle(spans)
    doc.spans["SPANS"] = SpanGroup(
        doc, name="SPANS", attrs={"key": "value"}, spans=spans
    )
    return doc


@pytest.fixture
def doc(en_tokenizer):
    """Doc with a populated "SPANS" span group."""
    return _build_doc_with_span_group(en_tokenizer)


@pytest.fixture
def other_doc(en_tokenizer):
    """A second, separately constructed doc with the same "SPANS" content.

    Used by tests that need spans or groups belonging to a different doc.
    """
    return _build_doc_with_span_group(en_tokenizer)


@pytest.fixture
def span_group(en_tokenizer):
    """Build a doc and its "SPANS" group, then let both go out of scope.

    NOTE(review): this fixture has no return statement, so tests that
    request it receive None. Presumably the intent was to hand back the
    SpanGroup after its doc has been released -- confirm before changing,
    since test_span_group_dealloc currently relies on the None value.
    """
    _build_doc_with_span_group(en_tokenizer)
|
||
|
|
||
|
|
||
|
def test_span_group_copy(doc):
    """A copied span group starts out equal in content but is independent."""
    original = doc.spans["SPANS"]
    duplicate = original.copy()

    # Right after copying: compares unequal as an object, same contents.
    assert duplicate != original
    assert duplicate.name == original.name
    assert duplicate.attrs == original.attrs
    assert len(duplicate) == len(original)
    assert list(original) == list(duplicate)

    # Change the name, the attrs and the span list of the copy only.
    duplicate.name = "new_name"
    duplicate.attrs["key"] = "new_value"
    duplicate.append(Span(doc, 0, 6, "LABEL"))

    # None of those changes may leak back into the source group.
    assert duplicate.name != original.name
    assert duplicate.attrs != original.attrs
    assert original.attrs["key"] == "value"
    assert list(original) != list(duplicate)
|
||
|
|
||
|
|
||
|
def test_span_group_set_item(doc, other_doc):
    """Item assignment stores the span, checks bounds and rejects other docs."""
    group = doc.spans["SPANS"]
    position = 5

    # Editing the span obtained from the group must not alter the stored one.
    edited = group[position]
    edited.label_ = "NEW LABEL"
    edited.kb_id = doc.vocab.strings["KB_ID"]
    assert group[position].label != edited.label
    assert group[position].kb_id != edited.kb_id

    # After writing it back, the stored span matches on every field.
    group[position] = edited
    for field in ("start", "end", "label", "kb_id"):
        assert getattr(group[position], field) == getattr(edited, field)
    assert group[position] == edited

    # Positions outside the group are rejected in both directions.
    for out_of_range in (-100, 100):
        with pytest.raises(IndexError):
            group[out_of_range] = edited

    # A span created on a different doc cannot be stored in this group.
    foreign = Span(other_doc, 0, 2)
    with pytest.raises(ValueError):
        group[position] = foreign
|
||
|
|
||
|
|
||
|
def test_span_group_has_overlap(doc):
    """The fixture's "SPANS" group reports that it has overlapping spans."""
    assert doc.spans["SPANS"].has_overlap
|
||
|
|
||
|
|
||
|
def test_span_group_concat(doc, other_doc):
    """_concat joins two groups of one doc, optionally in place."""
    base_group = doc.spans["SPANS"]
    extra_group = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=[doc[0:5], doc[0:6]],
    )
    # On the shared "key" the base group's value is kept; "new_key" is added.
    merged_attrs = {"key": "value", "new_key": "new_value"}

    # Default call: the result is returned as a separate group.
    combined = base_group._concat(extra_group)
    assert combined.name == base_group.name
    assert combined.attrs == merged_attrs
    expected_spans = list(base_group) + list(extra_group)
    assert list(combined) == list(expected_spans)

    # Inplace
    expected_spans = list(base_group) + list(extra_group)
    combined = base_group._concat(extra_group, inplace=True)
    assert combined == base_group
    assert combined.name == base_group.name
    assert combined.attrs == merged_attrs
    assert list(combined) == list(expected_spans)

    # A group that belongs to another doc cannot be concatenated.
    foreign_group = other_doc.spans["SPANS"]
    with pytest.raises(ValueError):
        base_group._concat(foreign_group)
|
||
|
|
||
|
|
||
|
def test_span_doc_delitem(doc):
    """Deleting an item shrinks the group and shifts its successor down."""
    group = doc.spans["SPANS"]
    size_before = len(group)
    position = 5
    removed = group[position]
    successor = group[position + 1]

    del group[position]

    assert len(group) == size_before - 1
    # The freed slot now holds what used to be the following span.
    assert group[position] != removed
    assert group[position] == successor

    # Positions outside the group are rejected in both directions.
    for out_of_range in (-100, 100):
        with pytest.raises(IndexError):
            del group[out_of_range]
|
||
|
|
||
|
|
||
|
def test_span_group_add(doc):
    """The + operator gives the same spans and attrs as _concat."""
    left = doc.spans["SPANS"]
    right = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=[doc[0:5], doc[0:6]],
    )

    # Reference result, computed through the private helper.
    reference = left._concat(right)

    total = left + right
    assert len(total) == len(reference)
    assert total.attrs == {"key": "value", "new_key": "new_value"}
    assert list(total) == list(reference)
|
||
|
|
||
|
|
||
|
def test_span_group_iadd(doc):
    """The += operator accepts both a SpanGroup and a plain list of spans."""
    extra_spans = [doc[0:5], doc[0:6]]
    extra_group = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=extra_spans,
    )

    # Case 1: right-hand side is a SpanGroup, so its attrs are merged in.
    target = doc.spans["SPANS"].copy()
    reference = target._concat(extra_group)

    target += extra_group
    assert len(target) == len(reference)
    assert target.attrs == {"key": "value", "new_key": "new_value"}
    assert list(target) == list(reference)

    # Case 2: right-hand side is a list of spans, so attrs stay as they were.
    target = doc.spans["SPANS"].copy()
    target += extra_spans
    assert len(target) == len(reference)
    assert target.attrs == {
        "key": "value",
    }
    assert list(target) == list(reference)
|
||
|
|
||
|
|
||
|
def test_span_group_extend(doc):
    """extend() accepts both a SpanGroup and a plain list of spans."""
    extra_spans = [doc[0:5], doc[0:6]]
    extra_group = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=extra_spans,
    )

    # Case 1: extending a copy with a SpanGroup also merges the attrs.
    target = doc.spans["SPANS"].copy()
    reference = target._concat(extra_group)

    target.extend(extra_group)
    assert len(target) == len(reference)
    assert target.attrs == {"key": "value", "new_key": "new_value"}
    assert list(target) == list(reference)

    # Case 2: extending the doc's own group (no copy here) with a list of
    # spans leaves the attrs as they were.
    target = doc.spans["SPANS"]
    target.extend(extra_spans)
    assert len(target) == len(reference)
    assert target.attrs == {"key": "value"}
    assert list(target) == list(reference)
|
||
|
|
||
|
|
||
|
def test_span_group_dealloc(span_group):
    """Reading `.doc` from the `span_group` fixture value raises AttributeError.

    NOTE(review): the `span_group` fixture in this file has no return
    statement, so the injected value is None and the AttributeError comes
    from `None.doc`. Confirm whether a live SpanGroup was meant to be
    exercised here.
    """
    expected_failure = pytest.raises(AttributeError)
    with expected_failure:
        print(span_group.doc)
|