From 36cb2029a9accea40285f62c4af365cb974b2ccd Mon Sep 17 00:00:00 2001 From: Peter Baumgartner <5107405+pmbaumgartner@users.noreply.github.com> Date: Fri, 8 Jul 2022 13:20:13 -0400 Subject: [PATCH] displaCy Spans Vertical Alignment Fix 2 (#11092) * add in span render slot fix * fix spacing off by one * rm demo * adjust comments * fix whitespace and overlap issue --- spacy/displacy/render.py | 61 ++++++++++++++++++++++++++++++++++------ 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/spacy/displacy/render.py b/spacy/displacy/render.py index a730ce522..50dc3466c 100644 --- a/spacy/displacy/render.py +++ b/spacy/displacy/render.py @@ -130,26 +130,56 @@ class SpanRenderer: title (str / None): Document title set in Doc.user_data['title']. """ per_token_info = [] + # we must sort so that we can correctly describe when spans need to "stack" + # which is determined by their start token, then span length (longer spans on top), + # then break any remaining ties with the span label + spans = sorted( + spans, + key=lambda s: ( + s["start_token"], + -(s["end_token"] - s["start_token"]), + s["label"], + ), + ) + for s in spans: + # this is the vertical 'slot' that the span will be rendered in + # vertical_position = span_label_offset + (offset_step * (slot - 1)) + s["render_slot"] = 0 for idx, token in enumerate(tokens): # Identify if a token belongs to a Span (and which) and if it's a # start token of said Span. We'll use this for the final HTML render token_markup: Dict[str, Any] = {} token_markup["text"] = token + concurrent_spans = 0 entities = [] for span in spans: ent = {} if span["start_token"] <= idx < span["end_token"]: + concurrent_spans += 1 + span_start = idx == span["start_token"] ent["label"] = span["label"] - ent["is_start"] = True if idx == span["start_token"] else False + ent["is_start"] = span_start + if span_start: + # When the span starts, we need to know how many other + # spans are on the 'span stack' and will be rendered. + # This value becomes the vertical render slot for this entire span + span["render_slot"] = concurrent_spans + ent["render_slot"] = span["render_slot"] kb_id = span.get("kb_id", "") kb_url = span.get("kb_url", "#") ent["kb_link"] = ( TPL_KB_LINK.format(kb_id=kb_id, kb_url=kb_url) if kb_id else "" ) entities.append(ent) + else: + # We don't specifically need to do this since we loop + # over tokens and spans sorted by their start_token, + # so we'll never use a span again after the last token it appears in, + # but if we were to use these spans again we'd want to make sure + # this value was reset correctly. + span["render_slot"] = 0 token_markup["entities"] = entities per_token_info.append(token_markup) - markup = self._render_markup(per_token_info) markup = TPL_SPANS.format(content=markup, dir=self.direction) if title: @@ -160,8 +190,12 @@ class SpanRenderer: """Render the markup from per-token information""" markup = "" for token in per_token_info: - entities = sorted(token["entities"], key=lambda d: d["label"]) - if entities: + entities = sorted(token["entities"], key=lambda d: d["render_slot"]) + # Whitespace tokens disrupt the vertical space (no line height) so that the + # span indicators get misaligned. We don't render them as individual + # tokens anyway, so we'll just not display a span indicator either. + is_whitespace = token["text"].strip() == "" + if entities and not is_whitespace: slices = self._get_span_slices(token["entities"]) starts = self._get_span_starts(token["entities"]) total_height = ( @@ -182,10 +216,18 @@ class SpanRenderer: def _get_span_slices(self, entities: List[Dict]) -> str: """Get the rendered markup of all Span slices""" span_slices = [] - for entity, step in zip(entities, itertools.count(step=self.offset_step)): + for entity in entities: + # rather than iterate over multiples of offset_step, we use entity['render_slot'] + # to determine the vertical position, since that tells where + # the span starts vertically so we can extend it horizontally, + # past other spans that might have already ended color = self.colors.get(entity["label"].upper(), self.default_color) + top_offset = self.top_offset + ( + self.offset_step * (entity["render_slot"] - 1) + ) span_slice = self.span_slice_template.format( - bg=color, top_offset=self.top_offset + step + bg=color, + top_offset=top_offset, ) span_slices.append(span_slice) return "".join(span_slices) @@ -193,12 +235,15 @@ class SpanRenderer: def _get_span_starts(self, entities: List[Dict]) -> str: """Get the rendered markup of all Span start tokens""" span_starts = [] - for entity, step in zip(entities, itertools.count(step=self.offset_step)): + for entity in entities: color = self.colors.get(entity["label"].upper(), self.default_color) + top_offset = self.top_offset + ( + self.offset_step * (entity["render_slot"] - 1) + ) span_start = ( self.span_start_template.format( bg=color, - top_offset=self.top_offset + step, + top_offset=top_offset, label=entity["label"], kb_link=entity["kb_link"], )