diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 983e1fba9..ff68a3703 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -726,6 +726,7 @@ class SpanCategorizer(TrainablePipe): if not allow_overlap: # Get the probabilities sort_idx = (argmax_scores.squeeze() * -1).argsort() + argmax_scores = argmax_scores[sort_idx] predicted = predicted[sort_idx] indices = indices[sort_idx] keeps = keeps[sort_idx] @@ -748,4 +749,5 @@ class SpanCategorizer(TrainablePipe): attrs_scores.append(argmax_scores[i]) spans.append(Span(doc, start, end, label=self.labels[label])) + spans.attrs["scores"] = numpy.array(attrs_scores) return spans diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index cf6304042..b06505a6d 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -190,17 +190,19 @@ def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results): spangroup = spancat._make_span_group_singlelabel( doc, indices, scores, allow_overlap ) - assert len(spangroup) == nr_results if threshold > 0.4: if allow_overlap: assert spangroup[0].text == "London" assert spangroup[0].label_ == "City" + assert_almost_equal(0.6, spangroup.attrs["scores"][0], 5) assert spangroup[1].text == "Greater London" assert spangroup[1].label_ == "GreatCity" - + assert spangroup.attrs["scores"][1] == 0.9 + assert_almost_equal(0.9, spangroup.attrs["scores"][1], 5) else: assert spangroup[0].text == "Greater London" assert spangroup[0].label_ == "GreatCity" + assert spangroup.attrs["scores"][0] == 0.9 else: if allow_overlap: assert spangroup[0].text == "Greater" @@ -256,22 +258,32 @@ def test_make_spangroup_negative_label(): assert len(spangroup_single) == 2 assert spangroup_single[0].text == "Greater" assert spangroup_single[0].label_ == "City" + assert_almost_equal(0.4, spangroup_single.attrs["scores"][0], 5) assert spangroup_single[1].text == "Greater London" assert spangroup_single[1].label_ == "GreatCity" + assert spangroup_single.attrs["scores"][1] == 0.9 + assert_almost_equal(0.9, spangroup_single.attrs["scores"][1], 5) assert len(spangroup_multi) == 6 assert spangroup_multi[0].text == "Greater" assert spangroup_multi[0].label_ == "City" + assert_almost_equal(0.4, spangroup_multi.attrs["scores"][0], 5) assert spangroup_multi[1].text == "Greater" assert spangroup_multi[1].label_ == "Person" + assert_almost_equal(0.3, spangroup_multi.attrs["scores"][1], 5) assert spangroup_multi[2].text == "London" assert spangroup_multi[2].label_ == "City" + assert_almost_equal(0.6, spangroup_multi.attrs["scores"][2], 5) assert spangroup_multi[3].text == "London" assert spangroup_multi[3].label_ == "GreatCity" + assert_almost_equal(0.4, spangroup_multi.attrs["scores"][3], 5) assert spangroup_multi[4].text == "Greater London" assert spangroup_multi[4].label_ == "Thing" + assert spangroup_multi[4].text == "Greater London" + assert_almost_equal(0.8, spangroup_multi.attrs["scores"][4], 5) assert spangroup_multi[5].text == "Greater London" assert spangroup_multi[5].label_ == "GreatCity" + assert_almost_equal(0.9, spangroup_multi.attrs["scores"][5], 5) def test_ngram_suggester(en_tokenizer):