Resize doc.tensor when merging spans. Closes #1963 (#3106)

The doc.retokenize() context manager wasn't resizing doc.tensor, leading to a mismatch between the number of tokens in the doc and the number of rows in the tensor. We fix this by deleting rows from the tensor. Merged spans are represented by the vector of their last token.

* Add test for resizing doc.tensor when merging

* Add test for resizing doc.tensor when merging. Closes #1963

* Update get_lca_matrix test for develop

* Fix retokenize if tensor unset
This commit is contained in:
Matthew Honnibal 2018-12-30 15:17:17 +01:00 committed by GitHub
parent 3d64eb4a74
commit 72e4d3782a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 2 deletions

View File

@ -247,6 +247,16 @@ def test_issue1945():
assert matches[1][1:] == (1, 3) assert matches[1][1:] == (1, 3)
def test_issue1963(en_tokenizer):
"""Test that doc.merge() resizes doc.tensor"""
doc = en_tokenizer('a b c d')
doc.tensor = numpy.ones((len(doc), 128), dtype='f')
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[0:2])
assert len(doc) == 3
assert doc.tensor.shape == (3, 128)
@pytest.mark.parametrize("label", ["U-JOB-NAME"]) @pytest.mark.parametrize("label", ["U-JOB-NAME"])
def test_issue1967(label): def test_issue1967(label):
ner = EntityRecognizer(Vocab()) ner = EntityRecognizer(Vocab())

View File

@ -7,7 +7,9 @@ from __future__ import unicode_literals
from libc.string cimport memcpy, memset from libc.string cimport memcpy, memset
from libc.stdlib cimport malloc, free from libc.stdlib cimport malloc, free
import numpy
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from thinc.neural.util import get_array_module
from .doc cimport Doc, set_children_from_heads, token_by_start, token_by_end from .doc cimport Doc, set_children_from_heads, token_by_start, token_by_end
from .span cimport Span from .span cimport Span
@ -83,6 +85,11 @@ def _merge(Doc doc, int start, int end, attributes):
cdef Span span = doc[start:end] cdef Span span = doc[start:end]
cdef int start_char = span.start_char cdef int start_char = span.start_char
cdef int end_char = span.end_char cdef int end_char = span.end_char
# Resize the doc.tensor, if it's set. Let the last row for each token stand
# for the merged region. To do this, we create a boolean array indicating
# whether the row is to be deleted, then use numpy.delete
if doc.tensor is not None and doc.tensor.size != 0:
doc.tensor = _resize_tensor(doc.tensor, [(start, end)])
# Get LexemeC for newly merged token # Get LexemeC for newly merged token
new_orth = ''.join([t.text_with_ws for t in span]) new_orth = ''.join([t.text_with_ws for t in span])
if span[-1].whitespace_: if span[-1].whitespace_:
@ -182,7 +189,12 @@ def _bulk_merge(Doc doc, merges):
else: else:
Token.set_struct_attr(token, attr_name, attr_value) Token.set_struct_attr(token, attr_name, attr_value)
# Resize the doc.tensor, if it's set. Let the last row for each token stand
# for the merged region. To do this, we create a boolean array indicating
# whether the row is to be deleted, then use numpy.delete
if doc.tensor is not None and doc.tensor.size != 0:
doc.tensor = _resize_tensor(doc.tensor,
[(m[1][0].start, m[1][0].end) for m in merges])
# Memorize span roots and sets dependencies of the newly merged # Memorize span roots and sets dependencies of the newly merged
# tokens to the dependencies of their roots. # tokens to the dependencies of their roots.
span_roots = [] span_roots = []
@ -276,6 +288,14 @@ def _bulk_merge(Doc doc, merges):
else: else:
# If they're not the same entity type, let them be two entities # If they're not the same entity type, let them be two entities
doc.c[token_after_span_position].ent_iob = 3 doc.c[token_after_span_position].ent_iob = 3
# Return the merged Python object # Return the merged Python object
return doc[spans[0].start] return doc[spans[0].start]
def _resize_tensor(tensor, ranges):
delete = []
for start, end in ranges:
for i in range(start, end-1):
delete.append(i)
xp = get_array_module(tensor)
return xp.delete(tensor, delete, axis=0)