💫 Fix displaCy support for RTL languages (#3393)

Closes #2091.

## Description

With the new `vocab.writing_system` property introduced in #3390 (exposed via the language defaults), I was able to finally fix this (I think!). Based on the `Doc`, dispaCy now detects whether it's a RTL or LTR language and adjusts the visualization accordingly. Wherever possible, I've also added `direction` and `lang` attributes.

Entity visualization now looks like this:

<img width="318" alt="Screenshot 2019-03-11 at 16 06 51" src="https://user-images.githubusercontent.com/13643239/54136866-d97afd80-441c-11e9-8c27-3d46994cc833.png">

And dependencies like this (ignore the most likely incorrect tags and dependencies):

<img width="621" alt="Screenshot 2019-03-11 at 16 51 59" src="https://user-images.githubusercontent.com/13643239/54137771-8b66f980-441e-11e9-8460-0682b95eef2a.png">

### Types of change
enhancement, bug fix

## Checklist
<!--- Before you submit the PR, go over this checklist and make sure you can
tick off all the boxes. [] -> [x] -->
- [x] I have submitted the spaCy Contributor Agreement.
- [x] I ran the tests, and all new and existing tests passed.
- [x] My changes don't require a change to the documentation, or if they do, I've added all required information.
This commit is contained in:
Ines Montani 2019-03-11 18:52:50 +01:00 committed by GitHub
parent b1f5f39a19
commit 4bd2688eac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 86 additions and 21 deletions

View File

@ -161,7 +161,7 @@ def parse_deps(orig_doc, options={}):
"dir": "right",
}
)
return {"words": words, "arcs": arcs}
return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}
def parse_ents(doc, options={}):
@ -177,7 +177,8 @@ def parse_ents(doc, options={}):
if not ents:
user_warning(Warnings.W006)
title = doc.user_data.get("title", None) if hasattr(doc, "user_data") else None
return {"text": doc.text, "ents": ents, "title": title}
settings = get_doc_settings(doc)
return {"text": doc.text, "ents": ents, "title": title, "settings": settings}
def set_render_wrapper(func):
@ -195,3 +196,10 @@ def set_render_wrapper(func):
if not hasattr(func, "__call__"):
raise ValueError(Errors.E110.format(obj=type(func)))
RENDER_WRAPPER = func
def get_doc_settings(doc):
return {
"lang": doc.lang_,
"direction": doc.vocab.writing_system.get("direction", "ltr"),
}

View File

@ -3,10 +3,13 @@ from __future__ import unicode_literals
import uuid
from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS
from .templates import TPL_ENT, TPL_ENTS, TPL_FIGURE, TPL_TITLE, TPL_PAGE
from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS, TPL_ENTS
from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE
from ..util import minify_html, escape_html
DEFAULT_LANG = "en"
DEFAULT_DIR = "ltr"
class DependencyRenderer(object):
"""Render dependency parses as SVGs."""
@ -30,6 +33,8 @@ class DependencyRenderer(object):
self.color = options.get("color", "#000000")
self.bg = options.get("bg", "#ffffff")
self.font = options.get("font", "Arial")
self.direction = DEFAULT_DIR
self.lang = DEFAULT_LANG
def render(self, parsed, page=False, minify=False):
"""Render complete markup.
@ -42,13 +47,19 @@ class DependencyRenderer(object):
# Create a random ID prefix to make sure parses don't receive the
# same ID, even if they're identical
id_prefix = uuid.uuid4().hex
rendered = [
self.render_svg("{}-{}".format(id_prefix, i), p["words"], p["arcs"])
for i, p in enumerate(parsed)
]
rendered = []
for i, p in enumerate(parsed):
if i == 0:
self.direction = p["settings"].get("direction", DEFAULT_DIR)
self.lang = p["settings"].get("lang", DEFAULT_LANG)
render_id = "{}-{}".format(id_prefix, i)
svg = self.render_svg(render_id, p["words"], p["arcs"])
rendered.append(svg)
if page:
content = "".join([TPL_FIGURE.format(content=svg) for svg in rendered])
markup = TPL_PAGE.format(content=content)
markup = TPL_PAGE.format(
content=content, lang=self.lang, dir=self.direction
)
else:
markup = "".join(rendered)
if minify:
@ -83,6 +94,8 @@ class DependencyRenderer(object):
bg=self.bg,
font=self.font,
content=content,
dir=self.direction,
lang=self.lang,
)
def render_word(self, text, tag, i):
@ -95,11 +108,13 @@ class DependencyRenderer(object):
"""
y = self.offset_y + self.word_spacing
x = self.offset_x + i * self.distance
if self.direction == "rtl":
x = self.width - x
html_text = escape_html(text)
return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y)
def render_arrow(self, label, start, end, direction, i):
"""Render indivicual arrow.
"""Render individual arrow.
label (unicode): Dependency label.
start (int): Index of start word.
@ -110,6 +125,8 @@ class DependencyRenderer(object):
"""
level = self.levels.index(end - start) + 1
x_start = self.offset_x + start * self.distance + self.arrow_spacing
if self.direction == "rtl":
x_start = self.width - x_start
y = self.offset_y
x_end = (
self.offset_x
@ -117,6 +134,8 @@ class DependencyRenderer(object):
+ start * self.distance
- self.arrow_spacing * (self.highest_level - level) / 4
)
if self.direction == "rtl":
x_end = self.width - x_end
y_curve = self.offset_y - level * self.distance / 2
if self.compact:
y_curve = self.offset_y - level * self.distance / 6
@ -124,12 +143,14 @@ class DependencyRenderer(object):
y_curve = -self.distance
arrowhead = self.get_arrowhead(direction, x_start, y, x_end)
arc = self.get_arc(x_start, y, y_curve, x_end)
label_side = "right" if self.direction == "rtl" else "left"
return TPL_DEP_ARCS.format(
id=self.id,
i=i,
stroke=self.arrow_stroke,
head=arrowhead,
label=label,
label_side=label_side,
arc=arc,
)
@ -219,6 +240,8 @@ class EntityRenderer(object):
self.default_color = "#ddd"
self.colors = colors
self.ents = options.get("ents", None)
self.direction = DEFAULT_DIR
self.lang = DEFAULT_LANG
def render(self, parsed, page=False, minify=False):
"""Render complete markup.
@ -228,12 +251,15 @@ class EntityRenderer(object):
minify (bool): Minify HTML markup.
RETURNS (unicode): Rendered HTML markup.
"""
rendered = [
self.render_ents(p["text"], p["ents"], p.get("title", None)) for p in parsed
]
rendered = []
for i, p in enumerate(parsed):
if i == 0:
self.direction = p["settings"].get("direction", DEFAULT_DIR)
self.lang = p["settings"].get("lang", DEFAULT_LANG)
rendered.append(self.render_ents(p["text"], p["ents"], p["title"]))
if page:
docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered])
markup = TPL_PAGE.format(content=docs)
markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction)
else:
markup = "".join(rendered)
if minify:
@ -261,12 +287,16 @@ class EntityRenderer(object):
markup += "</br>"
if self.ents is None or label.upper() in self.ents:
color = self.colors.get(label.upper(), self.default_color)
markup += TPL_ENT.format(label=label, text=entity, bg=color)
ent_settings = {"label": label, "text": entity, "bg": color}
if self.direction == "rtl":
markup += TPL_ENT_RTL.format(**ent_settings)
else:
markup += TPL_ENT.format(**ent_settings)
else:
markup += entity
offset = end
markup += escape_html(text[offset:])
markup = TPL_ENTS.format(content=markup, colors=self.colors)
markup = TPL_ENTS.format(content=markup, dir=self.direction)
if title:
markup = TPL_TITLE.format(title=title) + markup
return markup

View File

@ -6,7 +6,7 @@ from __future__ import unicode_literals
# Jupyter to render it properly in a cell
TPL_DEP_SVG = """
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" id="{id}" class="displacy" width="{width}" height="{height}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}">{content}</svg>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
"""
@ -22,7 +22,7 @@ TPL_DEP_ARCS = """
<g class="displacy-arrow">
<path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="currentColor"/>
<text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
<textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" fill="currentColor" text-anchor="middle">{label}</textPath>
<textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="currentColor" text-anchor="middle">{label}</textPath>
</text>
<path class="displacy-arrowhead" d="{head}" fill="currentColor"/>
</g>
@ -39,7 +39,7 @@ TPL_TITLE = """
TPL_ENTS = """
<div class="entities" style="line-height: 2.5">{content}</div>
<div class="entities" style="line-height: 2.5; direction: {dir}">{content}</div>
"""
@ -50,14 +50,21 @@ TPL_ENT = """
</mark>
"""
TPL_ENT_RTL = """
<mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
{text}
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-right: 0.5rem">{label}</span>
</mark>
"""
TPL_PAGE = """
<!DOCTYPE html>
<html>
<html lang="{lang}">
<head>
<title>displaCy</title>
</head>
<body style="font-size: 16px; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol'; padding: 4rem 2rem;">{content}</body>
<body style="font-size: 16px; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol'; padding: 4rem 2rem; direction: {dir}">{content}</body>
</html>
"""

View File

@ -4,6 +4,7 @@ from __future__ import unicode_literals
import pytest
from spacy import displacy
from spacy.tokens import Span
from spacy.lang.fa import Persian
from .util import get_doc
@ -66,3 +67,22 @@ def test_displacy_render_wrapper(en_vocab):
def test_displacy_raises_for_wrong_type(en_vocab):
with pytest.raises(ValueError):
displacy.render("hello world")
def test_displacy_rtl():
# Source: http://www.sobhe.ir/hazm/ is this correct?
words = ["ما", "بسیار", "کتاب", "می\u200cخوانیم"]
# These are (likely) wrong, but it's just for testing
pos = ["PRO", "ADV", "N_PL", "V_SUB"] # needs to match lang.fa.tag_map
deps = ["foo", "bar", "foo", "baz"]
heads = [1, 0, 1, -2]
nlp = Persian()
doc = get_doc(nlp.vocab, words=words, pos=pos, tags=pos, heads=heads, deps=deps)
doc.ents = [Span(doc, 1, 3, label="TEST")]
html = displacy.render(doc, page=True, style="dep")
assert "direction: rtl" in html
assert 'direction="rtl"' in html
assert 'lang="{}"'.format(nlp.lang) in html
html = displacy.render(doc, page=True, style="ent")
assert "direction: rtl" in html
assert 'lang="{}"'.format(nlp.lang) in html