Recalculate alignment if tokenization differs (#5868)

* Recalculate alignment if tokenization differs

* Refactor cached alignment data
This commit is contained in:
Adriane Boyd 2020-08-04 14:31:32 +02:00 committed by GitHub
parent 934447a611
commit b7e3018d97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 13 deletions

View File

@ -4,4 +4,6 @@ from ..tokens.doc cimport Doc
cdef class Example:
cdef readonly Doc x
cdef readonly Doc y
cdef readonly object _alignment
cdef readonly object _cached_alignment
cdef readonly object _cached_words_x
cdef readonly object _cached_words_y

View File

@ -32,9 +32,9 @@ cdef class Example:
raise TypeError(Errors.E972.format(arg="predicted"))
if reference is None:
raise TypeError(Errors.E972.format(arg="reference"))
self.x = predicted
self.y = reference
self._alignment = alignment
self.predicted = predicted
self.reference = reference
self._cached_alignment = alignment
def __len__(self):
return len(self.predicted)
@ -45,7 +45,8 @@ cdef class Example:
def __set__(self, doc):
self.x = doc
self._alignment = None
self._cached_alignment = None
self._cached_words_x = [t.text for t in doc]
property reference:
def __get__(self):
@ -53,7 +54,8 @@ cdef class Example:
def __set__(self, doc):
self.y = doc
self._alignment = None
self._cached_alignment = None
self._cached_words_y = [t.text for t in doc]
def copy(self):
return Example(
@ -79,13 +81,15 @@ cdef class Example:
@property
def alignment(self):
if self._alignment is None:
spacy_words = [token.orth_ for token in self.predicted]
gold_words = [token.orth_ for token in self.reference]
if gold_words == []:
gold_words = spacy_words
self._alignment = Alignment.from_strings(spacy_words, gold_words)
return self._alignment
words_x = [token.text for token in self.x]
words_y = [token.text for token in self.y]
if self._cached_alignment is None or \
words_x != self._cached_words_x or \
words_y != self._cached_words_y:
self._cached_alignment = Alignment.from_strings(words_x, words_y)
self._cached_words_x = words_x
self._cached_words_y = words_y
return self._cached_alignment
def get_aligned(self, field, as_string=False):
"""Return an aligned array for a token attribute."""

View File

@ -655,3 +655,18 @@ def test_split_sents(merged_dict):
assert token_annotation_2["words"] == ["It", "is", "just", "me"]
assert token_annotation_2["tags"] == ["PRON", "AUX", "ADV", "PRON"]
assert token_annotation_2["sent_starts"] == [1, 0, 0, 0]
def test_retokenized_docs(doc):
    """Check that Example.get_aligned follows an in-place retokenization.

    Two docs with the same words and TAG values are built from the ``doc``
    fixture and wrapped in an Example. After tokens of the first doc are
    merged, the aligned ORTH values must reflect the new tokenization.
    """
    tag_values = doc.to_array(["TAG"])

    def clone_doc():
        # Fresh Doc with the fixture's words and its TAG column copied over.
        words = [t.text for t in doc]
        return Doc(doc.vocab, words=words).from_array(["TAG"], tag_values)

    doc1 = clone_doc()
    doc2 = clone_doc()
    example = Example(doc1, doc2)

    # Both docs share one tokenization, so every token is aligned.
    expected_before = ['Sarah', "'s", 'sister', 'flew', 'to', 'Silicon', 'Valley', 'via', 'London', '.']
    assert example.get_aligned("ORTH", as_string=True) == expected_before

    # Merge "Sarah 's" and "Silicon Valley" in the first doc only.
    with doc1.retokenize() as retokenizer:
        retokenizer.merge(doc1[0:2])
        retokenizer.merge(doc1[5:7])

    # The merged tokens are expected to come back as None.
    expected_after = [None, 'sister', 'flew', 'to', None, 'via', 'London', '.']
    assert example.get_aligned("ORTH", as_string=True) == expected_after