diff --git a/spacy/tests/regression/test_issue5082.py b/spacy/tests/regression/test_issue5082.py
new file mode 100644
index 000000000..efa5d39f2
--- /dev/null
+++ b/spacy/tests/regression/test_issue5082.py
@@ -0,0 +1,46 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+import numpy as np
+from spacy.lang.en import English
+from spacy.pipeline import EntityRuler
+
+
+def test_issue5082():
+    # Ensure the 'merge_entities' pipeline sets a merged token's vector to the entity (span) vector
+    nlp = English()
+    vocab = nlp.vocab
+    array1 = np.asarray([0.1, 0.5, 0.8], dtype=np.float32)
+    array2 = np.asarray([-0.2, -0.6, -0.9], dtype=np.float32)
+    array3 = np.asarray([0.3, -0.1, 0.7], dtype=np.float32)
+    array4 = np.asarray([0.5, 0, 0.3], dtype=np.float32)
+    array34 = np.asarray([0.4, -0.05, 0.5], dtype=np.float32)
+
+    vocab.set_vector("I", array1)
+    vocab.set_vector("like", array2)
+    vocab.set_vector("David", array3)
+    vocab.set_vector("Bowie", array4)
+
+    text = "I like David Bowie"
+    ruler = EntityRuler(nlp)
+    patterns = [
+        {"label": "PERSON", "pattern": [{"LOWER": "david"}, {"LOWER": "bowie"}]}
+    ]
+    ruler.add_patterns(patterns)
+    nlp.add_pipe(ruler)
+
+    parsed_vectors_1 = [t.vector for t in nlp(text)]
+    assert len(parsed_vectors_1) == 4
+    np.testing.assert_array_equal(parsed_vectors_1[0], array1)
+    np.testing.assert_array_equal(parsed_vectors_1[1], array2)
+    np.testing.assert_array_equal(parsed_vectors_1[2], array3)
+    np.testing.assert_array_equal(parsed_vectors_1[3], array4)
+
+    merge_ents = nlp.create_pipe("merge_entities")
+    nlp.add_pipe(merge_ents)
+
+    parsed_vectors_2 = [t.vector for t in nlp(text)]
+    assert len(parsed_vectors_2) == 3
+    np.testing.assert_array_equal(parsed_vectors_2[0], array1)
+    np.testing.assert_array_equal(parsed_vectors_2[1], array2)
+    np.testing.assert_array_equal(parsed_vectors_2[2], array34)
diff --git a/spacy/tokens/_retokenize.pyx b/spacy/tokens/_retokenize.pyx
index a5d06491a..512ad73bc 100644
--- a/spacy/tokens/_retokenize.pyx
+++ b/spacy/tokens/_retokenize.pyx
@@ -213,6 +213,10 @@ def _merge(Doc doc, merges):
         new_orth = ''.join([t.text_with_ws for t in spans[token_index]])
         if spans[token_index][-1].whitespace_:
             new_orth = new_orth[:-len(spans[token_index][-1].whitespace_)]
+        # add the vector of the (merged) entity to the vocab
+        if not doc.vocab.get_vector(new_orth).any():
+            if doc.vocab.vectors_length > 0:
+                doc.vocab.set_vector(new_orth, span.vector)
         token = tokens[token_index]
         lex = doc.vocab.get(doc.mem, new_orth)
         token.lex = lex
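
Note on the expected value in the test, with a minimal sketch (not part of the patch, plain numpy only): spaCy computes Span.vector as the average of the span's token vectors, so the vector stored for the merged "David Bowie" token should be the elementwise mean of array3 and array4, which is exactly array34.

    import numpy as np

    # Word vectors assigned to "David" and "Bowie" in the test above
    array3 = np.asarray([0.3, -0.1, 0.7], dtype=np.float32)
    array4 = np.asarray([0.5, 0, 0.3], dtype=np.float32)

    # Span.vector averages the token vectors, so the merged token's
    # vector is expected to equal the elementwise mean of the two
    expected = (array3 + array4) / 2
    print(expected)  # [0.4, -0.05, 0.5], i.e. array34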