mirror of https://github.com/explosion/spaCy.git
Support doc.spans in Example.from_dict (#7197)
* add support for spans in Example.from_dict * add unit tests * update error to E879
This commit is contained in:
parent
fb98862337
commit
212f0e779e
|
@ -321,7 +321,8 @@ class Errors:
|
|||
"https://spacy.io/api/top-level#util.filter_spans")
|
||||
E103 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A "
|
||||
"token can only be part of one entity, so make sure the entities "
|
||||
"you're setting don't overlap.")
|
||||
"you're setting don't overlap. To work with overlapping entities, "
|
||||
"consider using doc.spans instead.")
|
||||
E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore "
|
||||
"settings: {opts}")
|
||||
E107 = ("Value of `doc._.{attr}` is not JSON-serializable: {value}")
|
||||
|
@ -487,6 +488,9 @@ class Errors:
|
|||
|
||||
# New errors added in v3.x
|
||||
|
||||
E879 = ("Unexpected type for 'spans' data. Provide a dictionary mapping keys to "
|
||||
"a list of spans, with each span represented by a tuple (start_char, end_char). "
|
||||
"The tuple can be optionally extended with a label and a KB ID.")
|
||||
E880 = ("The 'wandb' library could not be found - did you install it? "
|
||||
"Alternatively, specify the 'ConsoleLogger' in the 'training.logger' "
|
||||
"config section, instead of the 'WandbLogger'.")
|
||||
|
|
|
@ -196,6 +196,104 @@ def test_Example_from_dict_with_entities_invalid(annots):
|
|||
assert len(list(example.reference.ents)) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"annots",
|
||||
[
|
||||
{
|
||||
"words": ["I", "like", "New", "York", "and", "Berlin", "."],
|
||||
"entities": [
|
||||
(7, 15, "LOC"),
|
||||
(11, 15, "LOC"),
|
||||
(20, 26, "LOC"),
|
||||
], # overlapping
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_Example_from_dict_with_entities_overlapping(annots):
|
||||
vocab = Vocab()
|
||||
predicted = Doc(vocab, words=annots["words"])
|
||||
with pytest.raises(ValueError):
|
||||
Example.from_dict(predicted, annots)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"annots",
|
||||
[
|
||||
{
|
||||
"words": ["I", "like", "New", "York", "and", "Berlin", "."],
|
||||
"spans": {
|
||||
"cities": [(7, 15, "LOC"), (20, 26, "LOC")],
|
||||
"people": [(0, 1, "PERSON")],
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_Example_from_dict_with_spans(annots):
|
||||
vocab = Vocab()
|
||||
predicted = Doc(vocab, words=annots["words"])
|
||||
example = Example.from_dict(predicted, annots)
|
||||
assert len(list(example.reference.ents)) == 0
|
||||
assert len(list(example.reference.spans["cities"])) == 2
|
||||
assert len(list(example.reference.spans["people"])) == 1
|
||||
for span in example.reference.spans["cities"]:
|
||||
assert span.label_ == "LOC"
|
||||
for span in example.reference.spans["people"]:
|
||||
assert span.label_ == "PERSON"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"annots",
|
||||
[
|
||||
{
|
||||
"words": ["I", "like", "New", "York", "and", "Berlin", "."],
|
||||
"spans": {
|
||||
"cities": [(7, 15, "LOC"), (11, 15, "LOC"), (20, 26, "LOC")],
|
||||
"people": [(0, 1, "PERSON")],
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_Example_from_dict_with_spans_overlapping(annots):
|
||||
vocab = Vocab()
|
||||
predicted = Doc(vocab, words=annots["words"])
|
||||
example = Example.from_dict(predicted, annots)
|
||||
assert len(list(example.reference.ents)) == 0
|
||||
assert len(list(example.reference.spans["cities"])) == 3
|
||||
assert len(list(example.reference.spans["people"])) == 1
|
||||
for span in example.reference.spans["cities"]:
|
||||
assert span.label_ == "LOC"
|
||||
for span in example.reference.spans["people"]:
|
||||
assert span.label_ == "PERSON"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"annots",
|
||||
[
|
||||
{
|
||||
"words": ["I", "like", "New", "York", "and", "Berlin", "."],
|
||||
"spans": [(0, 1, "PERSON")],
|
||||
},
|
||||
{
|
||||
"words": ["I", "like", "New", "York", "and", "Berlin", "."],
|
||||
"spans": {"cities": (7, 15, "LOC")},
|
||||
},
|
||||
{
|
||||
"words": ["I", "like", "New", "York", "and", "Berlin", "."],
|
||||
"spans": {"cities": [7, 11]},
|
||||
},
|
||||
{
|
||||
"words": ["I", "like", "New", "York", "and", "Berlin", "."],
|
||||
"spans": {"cities": [[7]]},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_Example_from_dict_with_spans_invalid(annots):
|
||||
vocab = Vocab()
|
||||
predicted = Doc(vocab, words=annots["words"])
|
||||
with pytest.raises(ValueError):
|
||||
Example.from_dict(predicted, annots)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"annots",
|
||||
[
|
||||
|
|
|
@ -22,6 +22,8 @@ cpdef Doc annotations_to_doc(vocab, tok_annot, doc_annot):
|
|||
output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"])
|
||||
if "entities" in doc_annot:
|
||||
_add_entities_to_doc(output, doc_annot["entities"])
|
||||
if "spans" in doc_annot:
|
||||
_add_spans_to_doc(output, doc_annot["spans"])
|
||||
if array.size:
|
||||
output = output.from_array(attrs, array)
|
||||
# links are currently added with ENT_KB_ID on the token level
|
||||
|
@ -314,13 +316,11 @@ def _annot2array(vocab, tok_annot, doc_annot):
|
|||
|
||||
for key, value in doc_annot.items():
|
||||
if value:
|
||||
if key == "entities":
|
||||
if key in ["entities", "cats", "spans"]:
|
||||
pass
|
||||
elif key == "links":
|
||||
ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], tok_annot["SPACY"], value)
|
||||
tok_annot["ENT_KB_ID"] = ent_kb_ids
|
||||
elif key == "cats":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(Errors.E974.format(obj="doc", key=key))
|
||||
|
||||
|
@ -351,6 +351,29 @@ def _annot2array(vocab, tok_annot, doc_annot):
|
|||
return attrs, array.T
|
||||
|
||||
|
||||
def _add_spans_to_doc(doc, spans_data):
|
||||
if not isinstance(spans_data, dict):
|
||||
raise ValueError(Errors.E879)
|
||||
for key, span_list in spans_data.items():
|
||||
spans = []
|
||||
if not isinstance(span_list, list):
|
||||
raise ValueError(Errors.E879)
|
||||
for span_tuple in span_list:
|
||||
if not isinstance(span_tuple, (list, tuple)) or len(span_tuple) < 2:
|
||||
raise ValueError(Errors.E879)
|
||||
start_char = span_tuple[0]
|
||||
end_char = span_tuple[1]
|
||||
label = 0
|
||||
kb_id = 0
|
||||
if len(span_tuple) > 2:
|
||||
label = span_tuple[2]
|
||||
if len(span_tuple) > 3:
|
||||
kb_id = span_tuple[3]
|
||||
span = doc.char_span(start_char, end_char, label=label, kb_id=kb_id)
|
||||
spans.append(span)
|
||||
doc.spans[key] = spans
|
||||
|
||||
|
||||
def _add_entities_to_doc(doc, ner_data):
|
||||
if ner_data is None:
|
||||
return
|
||||
|
@ -397,7 +420,7 @@ def _fix_legacy_dict_data(example_dict):
|
|||
pass
|
||||
elif key == "ids":
|
||||
pass
|
||||
elif key in ("cats", "links"):
|
||||
elif key in ("cats", "links", "spans"):
|
||||
doc_dict[key] = value
|
||||
elif key in ("ner", "entities"):
|
||||
doc_dict["entities"] = value
|
||||
|
|
Loading…
Reference in New Issue