mirror of https://github.com/explosion/spaCy.git
Fix displacy span stacking (#13068)
* Fix displacy span stacking. * Format. Remove counter. * Remove test files. * Add unit test. Refactor to allow for unit test. * Fix off-by-one error in tests.
This commit is contained in:
parent
a804b83a4b
commit
c4e2daf6ef
|
@ -142,7 +142,25 @@ class SpanRenderer:
|
||||||
spans (list): Individual entity spans and their start, end, label, kb_id and kb_url.
|
spans (list): Individual entity spans and their start, end, label, kb_id and kb_url.
|
||||||
title (str / None): Document title set in Doc.user_data['title'].
|
title (str / None): Document title set in Doc.user_data['title'].
|
||||||
"""
|
"""
|
||||||
per_token_info = []
|
per_token_info = self._assemble_per_token_info(tokens, spans)
|
||||||
|
markup = self._render_markup(per_token_info)
|
||||||
|
markup = TPL_SPANS.format(content=markup, dir=self.direction)
|
||||||
|
if title:
|
||||||
|
markup = TPL_TITLE.format(title=title) + markup
|
||||||
|
return markup
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _assemble_per_token_info(
|
||||||
|
tokens: List[str], spans: List[Dict[str, Any]]
|
||||||
|
) -> List[Dict[str, List[Dict[str, Any]]]]:
|
||||||
|
"""Assembles token info used to generate markup in render_spans().
|
||||||
|
tokens (List[str]): Tokens in text.
|
||||||
|
spans (List[Dict[str, Any]]): Spans in text.
|
||||||
|
RETURNS (List[Dict[str, List[Dict, str, Any]]]): Per token info needed to render HTML markup for given tokens
|
||||||
|
and spans.
|
||||||
|
"""
|
||||||
|
per_token_info: List[Dict[str, List[Dict[str, Any]]]] = []
|
||||||
|
|
||||||
# we must sort so that we can correctly describe when spans need to "stack"
|
# we must sort so that we can correctly describe when spans need to "stack"
|
||||||
# which is determined by their start token, then span length (longer spans on top),
|
# which is determined by their start token, then span length (longer spans on top),
|
||||||
# then break any remaining ties with the span label
|
# then break any remaining ties with the span label
|
||||||
|
@ -154,21 +172,22 @@ class SpanRenderer:
|
||||||
s["label"],
|
s["label"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for s in spans:
|
for s in spans:
|
||||||
# this is the vertical 'slot' that the span will be rendered in
|
# this is the vertical 'slot' that the span will be rendered in
|
||||||
# vertical_position = span_label_offset + (offset_step * (slot - 1))
|
# vertical_position = span_label_offset + (offset_step * (slot - 1))
|
||||||
s["render_slot"] = 0
|
s["render_slot"] = 0
|
||||||
|
|
||||||
for idx, token in enumerate(tokens):
|
for idx, token in enumerate(tokens):
|
||||||
# Identify if a token belongs to a Span (and which) and if it's a
|
# Identify if a token belongs to a Span (and which) and if it's a
|
||||||
# start token of said Span. We'll use this for the final HTML render
|
# start token of said Span. We'll use this for the final HTML render
|
||||||
token_markup: Dict[str, Any] = {}
|
token_markup: Dict[str, Any] = {}
|
||||||
token_markup["text"] = token
|
token_markup["text"] = token
|
||||||
concurrent_spans = 0
|
intersecting_spans: List[Dict[str, Any]] = []
|
||||||
entities = []
|
entities = []
|
||||||
for span in spans:
|
for span in spans:
|
||||||
ent = {}
|
ent = {}
|
||||||
if span["start_token"] <= idx < span["end_token"]:
|
if span["start_token"] <= idx < span["end_token"]:
|
||||||
concurrent_spans += 1
|
|
||||||
span_start = idx == span["start_token"]
|
span_start = idx == span["start_token"]
|
||||||
ent["label"] = span["label"]
|
ent["label"] = span["label"]
|
||||||
ent["is_start"] = span_start
|
ent["is_start"] = span_start
|
||||||
|
@ -176,7 +195,12 @@ class SpanRenderer:
|
||||||
# When the span starts, we need to know how many other
|
# When the span starts, we need to know how many other
|
||||||
# spans are on the 'span stack' and will be rendered.
|
# spans are on the 'span stack' and will be rendered.
|
||||||
# This value becomes the vertical render slot for this entire span
|
# This value becomes the vertical render slot for this entire span
|
||||||
span["render_slot"] = concurrent_spans
|
span["render_slot"] = (
|
||||||
|
intersecting_spans[-1]["render_slot"]
|
||||||
|
if len(intersecting_spans)
|
||||||
|
else 0
|
||||||
|
) + 1
|
||||||
|
intersecting_spans.append(span)
|
||||||
ent["render_slot"] = span["render_slot"]
|
ent["render_slot"] = span["render_slot"]
|
||||||
kb_id = span.get("kb_id", "")
|
kb_id = span.get("kb_id", "")
|
||||||
kb_url = span.get("kb_url", "#")
|
kb_url = span.get("kb_url", "#")
|
||||||
|
@ -193,11 +217,8 @@ class SpanRenderer:
|
||||||
span["render_slot"] = 0
|
span["render_slot"] = 0
|
||||||
token_markup["entities"] = entities
|
token_markup["entities"] = entities
|
||||||
per_token_info.append(token_markup)
|
per_token_info.append(token_markup)
|
||||||
markup = self._render_markup(per_token_info)
|
|
||||||
markup = TPL_SPANS.format(content=markup, dir=self.direction)
|
return per_token_info
|
||||||
if title:
|
|
||||||
markup = TPL_TITLE.format(title=title) + markup
|
|
||||||
return markup
|
|
||||||
|
|
||||||
def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str:
|
def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str:
|
||||||
"""Render the markup from per-token information"""
|
"""Render the markup from per-token information"""
|
||||||
|
|
|
@ -2,7 +2,7 @@ import numpy
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from spacy import displacy
|
from spacy import displacy
|
||||||
from spacy.displacy.render import DependencyRenderer, EntityRenderer
|
from spacy.displacy.render import DependencyRenderer, EntityRenderer, SpanRenderer
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.lang.fa import Persian
|
from spacy.lang.fa import Persian
|
||||||
from spacy.tokens import Doc, Span
|
from spacy.tokens import Doc, Span
|
||||||
|
@ -468,3 +468,23 @@ def test_issue12816(en_vocab) -> None:
|
||||||
# Verify that the HTML tag is still escaped
|
# Verify that the HTML tag is still escaped
|
||||||
html = displacy.render(doc, style="span")
|
html = displacy.render(doc, style="span")
|
||||||
assert "<TEST>" in html
|
assert "<TEST>" in html
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.issue(13056)
|
||||||
|
def test_displacy_span_stacking():
|
||||||
|
"""Test whether span stacking works properly for multiple overlapping spans."""
|
||||||
|
spans = [
|
||||||
|
{"start_token": 2, "end_token": 5, "label": "SkillNC"},
|
||||||
|
{"start_token": 0, "end_token": 2, "label": "Skill"},
|
||||||
|
{"start_token": 1, "end_token": 3, "label": "Skill"},
|
||||||
|
]
|
||||||
|
tokens = ["Welcome", "to", "the", "Bank", "of", "China", "."]
|
||||||
|
per_token_info = SpanRenderer._assemble_per_token_info(spans=spans, tokens=tokens)
|
||||||
|
|
||||||
|
assert len(per_token_info) == len(tokens)
|
||||||
|
assert all([len(per_token_info[i]["entities"]) == 1 for i in (0, 3, 4)])
|
||||||
|
assert all([len(per_token_info[i]["entities"]) == 2 for i in (1, 2)])
|
||||||
|
assert per_token_info[1]["entities"][0]["render_slot"] == 1
|
||||||
|
assert per_token_info[1]["entities"][1]["render_slot"] == 2
|
||||||
|
assert per_token_info[2]["entities"][0]["render_slot"] == 2
|
||||||
|
assert per_token_info[2]["entities"][1]["render_slot"] == 3
|
||||||
|
|
Loading…
Reference in New Issue