From fef896ce49093357247d223e4f4d65d8811ac380 Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Thu, 3 Feb 2022 17:01:53 +0100
Subject: [PATCH] Allow Example to align whitespace annotation (#10189)

Remove exception for whitespace tokens in `Example.get_aligned` so that
annotation on whitespace tokens is aligned in the same way as for
non-whitespace tokens.
---
 spacy/tests/training/test_new_example.py | 10 ++++++++++
 spacy/training/example.pyx               | 21 +++++++++------------
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/spacy/tests/training/test_new_example.py b/spacy/tests/training/test_new_example.py
index 4dd90f416..a39d40ded 100644
--- a/spacy/tests/training/test_new_example.py
+++ b/spacy/tests/training/test_new_example.py
@@ -421,3 +421,13 @@ def test_Example_missing_heads():
     # Ensure that the missing head doesn't create an artificial new sentence start
     expected = [True, False, False, False, False, False]
     assert example.get_aligned_sent_starts() == expected
+
+
+def test_Example_aligned_whitespace(en_vocab):
+    words = ["a", " ", "b"]
+    tags = ["A", "SPACE", "B"]
+    predicted = Doc(en_vocab, words=words)
+    reference = Doc(en_vocab, words=words, tags=tags)
+
+    example = Example(predicted, reference)
+    assert example.get_aligned("TAG", as_string=True) == tags
diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx
index 732203e7b..d792c9bbf 100644
--- a/spacy/training/example.pyx
+++ b/spacy/training/example.pyx
@@ -159,20 +159,17 @@ cdef class Example:
         gold_values = self.reference.to_array([field])
         output = [None] * len(self.predicted)
         for token in self.predicted:
-            if token.is_space:
+            values = gold_values[align[token.i].dataXd]
+            values = values.ravel()
+            if len(values) == 0:
                 output[token.i] = None
+            elif len(values) == 1:
+                output[token.i] = values[0]
+            elif len(set(list(values))) == 1:
+                # If all aligned tokens have the same value, use it.
+                output[token.i] = values[0]
             else:
-                values = gold_values[align[token.i].dataXd]
-                values = values.ravel()
-                if len(values) == 0:
-                    output[token.i] = None
-                elif len(values) == 1:
-                    output[token.i] = values[0]
-                elif len(set(list(values))) == 1:
-                    # If all aligned tokens have the same value, use it.
-                    output[token.i] = values[0]
-                else:
-                    output[token.i] = None
+                output[token.i] = None
         if as_string and field not in ["ENT_IOB", "SENT_START"]:
             output = [vocab.strings[o] if o is not None else o for o in output]
         return output