From 212f0e779eeb1e1f66619bcb50739e4dbf90f4d5 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 2 Mar 2021 15:12:54 +0100 Subject: [PATCH] Support doc.spans in Example.from_dict (#7197) * add support for spans in Example.from_dict * add unit tests * update error to E879 --- spacy/errors.py | 6 +- spacy/tests/training/test_new_example.py | 98 ++++++++++++++++++++++++ spacy/training/example.pyx | 31 +++++++- 3 files changed, 130 insertions(+), 5 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index fc98fdaa6..2ebc49e8c 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -321,7 +321,8 @@ class Errors: "https://spacy.io/api/top-level#util.filter_spans") E103 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A " "token can only be part of one entity, so make sure the entities " - "you're setting don't overlap.") + "you're setting don't overlap. To work with overlapping entities, " + "consider using doc.spans instead.") E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore " "settings: {opts}") E107 = ("Value of `doc._.{attr}` is not JSON-serializable: {value}") @@ -487,6 +488,9 @@ class Errors: # New errors added in v3.x + E879 = ("Unexpected type for 'spans' data. Provide a dictionary mapping keys to " + "a list of spans, with each span represented by a tuple (start_char, end_char). " + "The tuple can be optionally extended with a label and a KB ID.") E880 = ("The 'wandb' library could not be found - did you install it? " "Alternatively, specify the 'ConsoleLogger' in the 'training.logger' " "config section, instead of the 'WandbLogger'.") diff --git a/spacy/tests/training/test_new_example.py b/spacy/tests/training/test_new_example.py index be3419b82..b8fbaf606 100644 --- a/spacy/tests/training/test_new_example.py +++ b/spacy/tests/training/test_new_example.py @@ -196,6 +196,104 @@ def test_Example_from_dict_with_entities_invalid(annots): assert len(list(example.reference.ents)) == 0 +@pytest.mark.parametrize( + "annots", + [ + { + "words": ["I", "like", "New", "York", "and", "Berlin", "."], + "entities": [ + (7, 15, "LOC"), + (11, 15, "LOC"), + (20, 26, "LOC"), + ], # overlapping + } + ], +) +def test_Example_from_dict_with_entities_overlapping(annots): + vocab = Vocab() + predicted = Doc(vocab, words=annots["words"]) + with pytest.raises(ValueError): + Example.from_dict(predicted, annots) + + +@pytest.mark.parametrize( + "annots", + [ + { + "words": ["I", "like", "New", "York", "and", "Berlin", "."], + "spans": { + "cities": [(7, 15, "LOC"), (20, 26, "LOC")], + "people": [(0, 1, "PERSON")], + }, + } + ], +) +def test_Example_from_dict_with_spans(annots): + vocab = Vocab() + predicted = Doc(vocab, words=annots["words"]) + example = Example.from_dict(predicted, annots) + assert len(list(example.reference.ents)) == 0 + assert len(list(example.reference.spans["cities"])) == 2 + assert len(list(example.reference.spans["people"])) == 1 + for span in example.reference.spans["cities"]: + assert span.label_ == "LOC" + for span in example.reference.spans["people"]: + assert span.label_ == "PERSON" + + +@pytest.mark.parametrize( + "annots", + [ + { + "words": ["I", "like", "New", "York", "and", "Berlin", "."], + "spans": { + "cities": [(7, 15, "LOC"), (11, 15, "LOC"), (20, 26, "LOC")], + "people": [(0, 1, "PERSON")], + }, + } + ], +) +def test_Example_from_dict_with_spans_overlapping(annots): + vocab = Vocab() + predicted = Doc(vocab, words=annots["words"]) + example = Example.from_dict(predicted, annots) + assert len(list(example.reference.ents)) == 0 + assert len(list(example.reference.spans["cities"])) == 3 + assert len(list(example.reference.spans["people"])) == 1 + for span in example.reference.spans["cities"]: + assert span.label_ == "LOC" + for span in example.reference.spans["people"]: + assert span.label_ == "PERSON" + + +@pytest.mark.parametrize( + "annots", + [ + { + "words": ["I", "like", "New", "York", "and", "Berlin", "."], + "spans": [(0, 1, "PERSON")], + }, + { + "words": ["I", "like", "New", "York", "and", "Berlin", "."], + "spans": {"cities": (7, 15, "LOC")}, + }, + { + "words": ["I", "like", "New", "York", "and", "Berlin", "."], + "spans": {"cities": [7, 11]}, + }, + { + "words": ["I", "like", "New", "York", "and", "Berlin", "."], + "spans": {"cities": [[7]]}, + }, + ], +) +def test_Example_from_dict_with_spans_invalid(annots): + vocab = Vocab() + predicted = Doc(vocab, words=annots["words"]) + with pytest.raises(ValueError): + Example.from_dict(predicted, annots) + + @pytest.mark.parametrize( "annots", [ diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx index dc1c74e8a..9cf825bf9 100644 --- a/spacy/training/example.pyx +++ b/spacy/training/example.pyx @@ -22,6 +22,8 @@ cpdef Doc annotations_to_doc(vocab, tok_annot, doc_annot): output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"]) if "entities" in doc_annot: _add_entities_to_doc(output, doc_annot["entities"]) + if "spans" in doc_annot: + _add_spans_to_doc(output, doc_annot["spans"]) if array.size: output = output.from_array(attrs, array) # links are currently added with ENT_KB_ID on the token level @@ -314,13 +316,11 @@ def _annot2array(vocab, tok_annot, doc_annot): for key, value in doc_annot.items(): if value: - if key == "entities": + if key in ["entities", "cats", "spans"]: pass elif key == "links": ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], tok_annot["SPACY"], value) tok_annot["ENT_KB_ID"] = ent_kb_ids - elif key == "cats": - pass else: raise ValueError(Errors.E974.format(obj="doc", key=key)) @@ -351,6 +351,29 @@ def _annot2array(vocab, tok_annot, doc_annot): return attrs, array.T +def _add_spans_to_doc(doc, spans_data): + if not isinstance(spans_data, dict): + raise ValueError(Errors.E879) + for key, span_list in spans_data.items(): + spans = [] + if not isinstance(span_list, list): + raise ValueError(Errors.E879) + for span_tuple in span_list: + if not isinstance(span_tuple, (list, tuple)) or len(span_tuple) < 2: + raise ValueError(Errors.E879) + start_char = span_tuple[0] + end_char = span_tuple[1] + label = 0 + kb_id = 0 + if len(span_tuple) > 2: + label = span_tuple[2] + if len(span_tuple) > 3: + kb_id = span_tuple[3] + span = doc.char_span(start_char, end_char, label=label, kb_id=kb_id) + spans.append(span) + doc.spans[key] = spans + + def _add_entities_to_doc(doc, ner_data): if ner_data is None: return @@ -397,7 +420,7 @@ def _fix_legacy_dict_data(example_dict): pass elif key == "ids": pass - elif key in ("cats", "links"): + elif key in ("cats", "links", "spans"): doc_dict[key] = value elif key in ("ner", "entities"): doc_dict["entities"] = value