mirror of https://github.com/explosion/spaCy.git
Require that all SpanGroup spans are from the current doc (#12569)
* Require that all SpanGroup spans are from the current doc The restriction on only adding spans from the current doc were already implemented for all operations except for `SpanGroup.__init__`. Initialize copied spans for `SpanGroup.copy` with `Doc.char_span` in order to validate the character offsets and to make it possible to copy spans between documents with differing tokenization. Currently there is no validation that the document texts are identical, but the span char offsets must be valid spans in the target doc, which prevents you from ending up with completely invalid spans. * Undo change in test_beam_overfitting_IO
This commit is contained in:
parent
05df59fd4a
commit
c4112a1da3
|
@ -970,6 +970,9 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
|
||||
"or use `auto_select_port=True` to pick an available port automatically.")
|
||||
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
|
||||
E1052 = ("Unable to copy spans: the character offsets for the span at "
|
||||
"index {i} in the span group do not align with the tokenization "
|
||||
"in the target doc.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -93,6 +93,21 @@ def test_span_group_copy(doc):
|
|||
assert span_group.attrs["key"] == "value"
|
||||
assert list(span_group) != list(clone)
|
||||
|
||||
# can't copy if the character offsets don't align to tokens
|
||||
doc2 = Doc(doc.vocab, words=[t.text + "x" for t in doc])
|
||||
with pytest.raises(ValueError):
|
||||
span_group.copy(doc=doc2)
|
||||
|
||||
# can copy with valid character offsets despite different tokenization
|
||||
doc3 = doc.copy()
|
||||
with doc3.retokenize() as retokenizer:
|
||||
retokenizer.merge(doc3[0:2])
|
||||
retokenizer.merge(doc3[3:6])
|
||||
span_group = SpanGroup(doc, spans=[doc[0:6], doc[3:6]])
|
||||
for span1, span2 in zip(span_group, span_group.copy(doc=doc3)):
|
||||
assert span1.start_char == span2.start_char
|
||||
assert span1.end_char == span2.end_char
|
||||
|
||||
|
||||
def test_span_group_set_item(doc, other_doc):
|
||||
span_group = doc.spans["SPANS"]
|
||||
|
@ -253,3 +268,12 @@ def test_span_group_typing(doc: Doc):
|
|||
for i, span in enumerate(span_group):
|
||||
assert span == span_group[i] == spans[i]
|
||||
filter_spans(span_group)
|
||||
|
||||
|
||||
def test_span_group_init_doc(en_tokenizer):
|
||||
"""Test that all spans must come from the specified doc."""
|
||||
doc1 = en_tokenizer("a b c")
|
||||
doc2 = en_tokenizer("a b c")
|
||||
span_group = SpanGroup(doc1, spans=[doc1[0:1], doc1[1:2]])
|
||||
with pytest.raises(ValueError):
|
||||
span_group = SpanGroup(doc1, spans=[doc1[0:1], doc2[1:2]])
|
||||
|
|
|
@ -728,9 +728,9 @@ def test_neg_annotation(neg_key):
|
|||
ner.add_label("ORG")
|
||||
example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]})
|
||||
example.reference.spans[neg_key] = [
|
||||
Span(neg_doc, 2, 4, "ORG"),
|
||||
Span(neg_doc, 2, 3, "PERSON"),
|
||||
Span(neg_doc, 1, 4, "PERSON"),
|
||||
Span(example.reference, 2, 4, "ORG"),
|
||||
Span(example.reference, 2, 3, "PERSON"),
|
||||
Span(example.reference, 1, 4, "PERSON"),
|
||||
]
|
||||
|
||||
optimizer = nlp.initialize()
|
||||
|
@ -755,7 +755,7 @@ def test_neg_annotation_conflict(neg_key):
|
|||
ner.add_label("PERSON")
|
||||
ner.add_label("LOC")
|
||||
example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]})
|
||||
example.reference.spans[neg_key] = [Span(neg_doc, 2, 4, "PERSON")]
|
||||
example.reference.spans[neg_key] = [Span(example.reference, 2, 4, "PERSON")]
|
||||
assert len(example.reference.ents) == 1
|
||||
assert example.reference.ents[0].text == "Shaka Khan"
|
||||
assert example.reference.ents[0].label_ == "PERSON"
|
||||
|
@ -788,7 +788,7 @@ def test_beam_valid_parse(neg_key):
|
|||
|
||||
doc = Doc(nlp.vocab, words=tokens)
|
||||
example = Example.from_dict(doc, {"ner": iob})
|
||||
neg_span = Span(doc, 50, 53, "ORG")
|
||||
neg_span = Span(example.reference, 50, 53, "ORG")
|
||||
example.reference.spans[neg_key] = [neg_span]
|
||||
|
||||
optimizer = nlp.initialize()
|
||||
|
|
|
@ -438,14 +438,14 @@ def test_score_spans():
|
|||
return doc.spans[span_key]
|
||||
|
||||
# Predict exactly the same, but overlapping spans will be discarded
|
||||
pred.spans[key] = spans
|
||||
pred.spans[key] = gold.spans[key].copy(doc=pred)
|
||||
eg = Example(pred, gold)
|
||||
scores = Scorer.score_spans([eg], attr=key, getter=span_getter)
|
||||
assert scores[f"{key}_p"] == 1.0
|
||||
assert scores[f"{key}_r"] < 1.0
|
||||
|
||||
# Allow overlapping, now both precision and recall should be 100%
|
||||
pred.spans[key] = spans
|
||||
pred.spans[key] = gold.spans[key].copy(doc=pred)
|
||||
eg = Example(pred, gold)
|
||||
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True)
|
||||
assert scores[f"{key}_p"] == 1.0
|
||||
|
|
|
@ -1264,12 +1264,14 @@ cdef class Doc:
|
|||
other.user_span_hooks = dict(self.user_span_hooks)
|
||||
other.length = self.length
|
||||
other.max_length = self.max_length
|
||||
other.spans = self.spans.copy(doc=other)
|
||||
buff_size = other.max_length + (PADDING*2)
|
||||
assert buff_size > 0
|
||||
tokens = <TokenC*>other.mem.alloc(buff_size, sizeof(TokenC))
|
||||
memcpy(tokens, self.c - PADDING, buff_size * sizeof(TokenC))
|
||||
other.c = &tokens[PADDING]
|
||||
# copy spans after setting tokens so that SpanGroup.copy can verify
|
||||
# that the start/end offsets are valid
|
||||
other.spans = self.spans.copy(doc=other)
|
||||
return other
|
||||
|
||||
def to_disk(self, path, *, exclude=tuple()):
|
||||
|
|
|
@ -52,6 +52,8 @@ cdef class SpanGroup:
|
|||
if len(spans) :
|
||||
self.c.reserve(len(spans))
|
||||
for span in spans:
|
||||
if doc is not span.doc:
|
||||
raise ValueError(Errors.E855.format(obj="span"))
|
||||
self.push_back(span.c)
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -261,11 +263,22 @@ cdef class SpanGroup:
|
|||
"""
|
||||
if doc is None:
|
||||
doc = self.doc
|
||||
if doc is self.doc:
|
||||
spans = list(self)
|
||||
else:
|
||||
spans = [doc.char_span(span.start_char, span.end_char, label=span.label_, kb_id=span.kb_id, span_id=span.id) for span in self]
|
||||
for i, span in enumerate(spans):
|
||||
if span is None:
|
||||
raise ValueError(Errors.E1052.format(i=i))
|
||||
if span.kb_id in self.doc.vocab.strings:
|
||||
doc.vocab.strings.add(span.kb_id_)
|
||||
if span.id in span.doc.vocab.strings:
|
||||
doc.vocab.strings.add(span.id_)
|
||||
return SpanGroup(
|
||||
doc,
|
||||
name=self.name,
|
||||
attrs=deepcopy(self.attrs),
|
||||
spans=list(self),
|
||||
spans=spans,
|
||||
)
|
||||
|
||||
def _concat(
|
||||
|
|
Loading…
Reference in New Issue