From c4112a1da3e88f54f997ac90036e4c184d22098f Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 1 Jun 2023 19:19:17 +0200 Subject: [PATCH] Require that all SpanGroup spans are from the current doc (#12569) * Require that all SpanGroup spans are from the current doc The restriction on only adding spans from the current doc were already implemented for all operations except for `SpanGroup.__init__`. Initialize copied spans for `SpanGroup.copy` with `Doc.char_span` in order to validate the character offsets and to make it possible to copy spans between documents with differing tokenization. Currently there is no validation that the document texts are identical, but the span char offsets must be valid spans in the target doc, which prevents you from ending up with completely invalid spans. * Undo change in test_beam_overfitting_IO --- spacy/errors.py | 3 +++ spacy/tests/doc/test_span_group.py | 24 ++++++++++++++++++++++++ spacy/tests/parser/test_ner.py | 10 +++++----- spacy/tests/test_scorer.py | 4 ++-- spacy/tokens/doc.pyx | 4 +++- spacy/tokens/span_group.pyx | 15 ++++++++++++++- 6 files changed, 51 insertions(+), 9 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 40cfa8d92..157110925 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -970,6 +970,9 @@ class Errors(metaclass=ErrorsWithCodes): E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` " "or use `auto_select_port=True` to pick an available port automatically.") E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.") + E1052 = ("Unable to copy spans: the character offsets for the span at " + "index {i} in the span group do not align with the tokenization " + "in the target doc.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/tests/doc/test_span_group.py b/spacy/tests/doc/test_span_group.py index 818569c64..cea2c42ee 100644 --- a/spacy/tests/doc/test_span_group.py +++ b/spacy/tests/doc/test_span_group.py @@ -93,6 +93,21 @@ def test_span_group_copy(doc): assert span_group.attrs["key"] == "value" assert list(span_group) != list(clone) + # can't copy if the character offsets don't align to tokens + doc2 = Doc(doc.vocab, words=[t.text + "x" for t in doc]) + with pytest.raises(ValueError): + span_group.copy(doc=doc2) + + # can copy with valid character offsets despite different tokenization + doc3 = doc.copy() + with doc3.retokenize() as retokenizer: + retokenizer.merge(doc3[0:2]) + retokenizer.merge(doc3[3:6]) + span_group = SpanGroup(doc, spans=[doc[0:6], doc[3:6]]) + for span1, span2 in zip(span_group, span_group.copy(doc=doc3)): + assert span1.start_char == span2.start_char + assert span1.end_char == span2.end_char + def test_span_group_set_item(doc, other_doc): span_group = doc.spans["SPANS"] @@ -253,3 +268,12 @@ def test_span_group_typing(doc: Doc): for i, span in enumerate(span_group): assert span == span_group[i] == spans[i] filter_spans(span_group) + + +def test_span_group_init_doc(en_tokenizer): + """Test that all spans must come from the specified doc.""" + doc1 = en_tokenizer("a b c") + doc2 = en_tokenizer("a b c") + span_group = SpanGroup(doc1, spans=[doc1[0:1], doc1[1:2]]) + with pytest.raises(ValueError): + span_group = SpanGroup(doc1, spans=[doc1[0:1], doc2[1:2]]) diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index 030182a63..7198859b3 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -728,9 +728,9 @@ def test_neg_annotation(neg_key): ner.add_label("ORG") example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]}) example.reference.spans[neg_key] = [ - Span(neg_doc, 2, 4, "ORG"), - Span(neg_doc, 2, 3, "PERSON"), - Span(neg_doc, 1, 4, "PERSON"), + Span(example.reference, 2, 4, "ORG"), + Span(example.reference, 2, 3, "PERSON"), + Span(example.reference, 1, 4, "PERSON"), ] optimizer = nlp.initialize() @@ -755,7 +755,7 @@ def test_neg_annotation_conflict(neg_key): ner.add_label("PERSON") ner.add_label("LOC") example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]}) - example.reference.spans[neg_key] = [Span(neg_doc, 2, 4, "PERSON")] + example.reference.spans[neg_key] = [Span(example.reference, 2, 4, "PERSON")] assert len(example.reference.ents) == 1 assert example.reference.ents[0].text == "Shaka Khan" assert example.reference.ents[0].label_ == "PERSON" @@ -788,7 +788,7 @@ def test_beam_valid_parse(neg_key): doc = Doc(nlp.vocab, words=tokens) example = Example.from_dict(doc, {"ner": iob}) - neg_span = Span(doc, 50, 53, "ORG") + neg_span = Span(example.reference, 50, 53, "ORG") example.reference.spans[neg_key] = [neg_span] optimizer = nlp.initialize() diff --git a/spacy/tests/test_scorer.py b/spacy/tests/test_scorer.py index 4b2d22986..f95c44149 100644 --- a/spacy/tests/test_scorer.py +++ b/spacy/tests/test_scorer.py @@ -438,14 +438,14 @@ def test_score_spans(): return doc.spans[span_key] # Predict exactly the same, but overlapping spans will be discarded - pred.spans[key] = spans + pred.spans[key] = gold.spans[key].copy(doc=pred) eg = Example(pred, gold) scores = Scorer.score_spans([eg], attr=key, getter=span_getter) assert scores[f"{key}_p"] == 1.0 assert scores[f"{key}_r"] < 1.0 # Allow overlapping, now both precision and recall should be 100% - pred.spans[key] = spans + pred.spans[key] = gold.spans[key].copy(doc=pred) eg = Example(pred, gold) scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True) assert scores[f"{key}_p"] == 1.0 diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index a54b4ad3c..6c196ad78 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1264,12 +1264,14 @@ cdef class Doc: other.user_span_hooks = dict(self.user_span_hooks) other.length = self.length other.max_length = self.max_length - other.spans = self.spans.copy(doc=other) buff_size = other.max_length + (PADDING*2) assert buff_size > 0 tokens = other.mem.alloc(buff_size, sizeof(TokenC)) memcpy(tokens, self.c - PADDING, buff_size * sizeof(TokenC)) other.c = &tokens[PADDING] + # copy spans after setting tokens so that SpanGroup.copy can verify + # that the start/end offsets are valid + other.spans = self.spans.copy(doc=other) return other def to_disk(self, path, *, exclude=tuple()): diff --git a/spacy/tokens/span_group.pyx b/spacy/tokens/span_group.pyx index 608dda283..c748fa256 100644 --- a/spacy/tokens/span_group.pyx +++ b/spacy/tokens/span_group.pyx @@ -52,6 +52,8 @@ cdef class SpanGroup: if len(spans) : self.c.reserve(len(spans)) for span in spans: + if doc is not span.doc: + raise ValueError(Errors.E855.format(obj="span")) self.push_back(span.c) def __repr__(self): @@ -261,11 +263,22 @@ cdef class SpanGroup: """ if doc is None: doc = self.doc + if doc is self.doc: + spans = list(self) + else: + spans = [doc.char_span(span.start_char, span.end_char, label=span.label_, kb_id=span.kb_id, span_id=span.id) for span in self] + for i, span in enumerate(spans): + if span is None: + raise ValueError(Errors.E1052.format(i=i)) + if span.kb_id in self.doc.vocab.strings: + doc.vocab.strings.add(span.kb_id_) + if span.id in span.doc.vocab.strings: + doc.vocab.strings.add(span.id_) return SpanGroup( doc, name=self.name, attrs=deepcopy(self.attrs), - spans=list(self), + spans=spans, ) def _concat(