mirror of https://github.com/explosion/spaCy.git
Fix most_similar for vectors with unused rows (#5348)
* Fix most_similar for vectors with unused rows Address issues related to the unused rows in the vector table and `most_similar`: * Update `most_similar()` to search only through rows that are in use according to `key2row`. * Raise an error when `most_similar(n=n)` is larger than the number of vectors in the table. * Set and restore `_unset` correctly when vectors are added or deserialized so that new vectors are added in the correct row. * Set data and keys to the same length in `Vocab.prune_vectors()` to avoid spurious entries in `key2row`. * Fix regression test using `most_similar` Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent
70da1fd2d6
commit
40e65d6f63
|
@ -564,6 +564,8 @@ class Errors(object):
|
|||
E196 = ("Refusing to write to token.is_sent_end. Sentence boundaries can "
|
||||
"only be fixed with token.is_sent_start.")
|
||||
E197 = ("Row out of bounds, unable to add row {row} for key {key}.")
|
||||
E198 = ("Unable to return {n} most similar vectors for the current vectors "
|
||||
"table, which contains {n_rows} vectors.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -295,7 +295,7 @@ def test_issue3410():
|
|||
|
||||
def test_issue3412():
|
||||
data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")
|
||||
vectors = Vectors(data=data)
|
||||
vectors = Vectors(data=data, keys=["A", "B", "C"])
|
||||
keys, best_rows, scores = vectors.most_similar(
|
||||
numpy.asarray([[9, 8, 7], [0, 0, 0]], dtype="f")
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import pytest
|
||||
import numpy
|
||||
from numpy.testing import assert_allclose
|
||||
from numpy.testing import assert_allclose, assert_equal
|
||||
from spacy._ml import cosine
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.vectors import Vectors
|
||||
|
@ -11,7 +11,7 @@ from spacy.tokenizer import Tokenizer
|
|||
from spacy.strings import hash_string
|
||||
from spacy.tokens import Doc
|
||||
|
||||
from ..util import add_vecs_to_vocab
|
||||
from ..util import add_vecs_to_vocab, make_tempdir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -59,6 +59,11 @@ def most_similar_vectors_data():
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def most_similar_vectors_keys():
|
||||
return ["a", "b", "c", "d"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resize_data():
|
||||
return numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype="f")
|
||||
|
@ -146,11 +151,14 @@ def test_set_vector(strings, data):
|
|||
assert list(v[strings[0]]) != list(orig[0])
|
||||
|
||||
|
||||
def test_vectors_most_similar(most_similar_vectors_data):
|
||||
v = Vectors(data=most_similar_vectors_data)
|
||||
def test_vectors_most_similar(most_similar_vectors_data, most_similar_vectors_keys):
|
||||
v = Vectors(data=most_similar_vectors_data, keys=most_similar_vectors_keys)
|
||||
_, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True)
|
||||
assert all(row[0] == i for i, row in enumerate(best_rows))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
v.most_similar(v.data, batch_size=2, n=10, sort=True)
|
||||
|
||||
|
||||
def test_vectors_most_similar_identical():
|
||||
"""Test that most similar identical vectors are assigned a score of 1.0."""
|
||||
|
@ -331,6 +339,33 @@ def test_vocab_prune_vectors():
|
|||
assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-4, rtol=1e-3)
|
||||
|
||||
|
||||
def test_vectors_serialize():
|
||||
data = numpy.asarray([[4, 2, 2, 2], [4, 2, 2, 2], [1, 1, 1, 1]], dtype="f")
|
||||
v = Vectors(data=data, keys=["A", "B", "C"])
|
||||
b = v.to_bytes()
|
||||
v_r = Vectors()
|
||||
v_r.from_bytes(b)
|
||||
assert_equal(v.data, v_r.data)
|
||||
assert v.key2row == v_r.key2row
|
||||
v.resize((5, 4))
|
||||
v_r.resize((5, 4))
|
||||
row = v.add("D", vector=numpy.asarray([1, 2, 3, 4], dtype="f"))
|
||||
row_r = v_r.add("D", vector=numpy.asarray([1, 2, 3, 4], dtype="f"))
|
||||
assert row == row_r
|
||||
assert_equal(v.data, v_r.data)
|
||||
assert v.is_full == v_r.is_full
|
||||
with make_tempdir() as d:
|
||||
v.to_disk(d)
|
||||
v_r.from_disk(d)
|
||||
assert_equal(v.data, v_r.data)
|
||||
assert v.key2row == v_r.key2row
|
||||
v.resize((5, 4))
|
||||
v_r.resize((5, 4))
|
||||
row = v.add("D", vector=numpy.asarray([10, 20, 30, 40], dtype="f"))
|
||||
row_r = v_r.add("D", vector=numpy.asarray([10, 20, 30, 40], dtype="f"))
|
||||
assert row == row_r
|
||||
assert_equal(v.data, v_r.data)
|
||||
|
||||
def test_vector_is_oov():
|
||||
vocab = Vocab(vectors_name="test_vocab_is_oov")
|
||||
data = numpy.ndarray((5, 3), dtype="f")
|
||||
|
@ -340,4 +375,4 @@ def test_vector_is_oov():
|
|||
vocab.set_vector("dog", data[1])
|
||||
assert vocab["cat"].is_oov is True
|
||||
assert vocab["dog"].is_oov is True
|
||||
assert vocab["hamster"].is_oov is False
|
||||
assert vocab["hamster"].is_oov is False
|
|
@ -212,8 +212,7 @@ cdef class Vectors:
|
|||
copy_shape = (min(shape[0], self.data.shape[0]), min(shape[1], self.data.shape[1]))
|
||||
resized_array[:copy_shape[0], :copy_shape[1]] = self.data[:copy_shape[0], :copy_shape[1]]
|
||||
self.data = resized_array
|
||||
filled = {row for row in self.key2row.values()}
|
||||
self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled})
|
||||
self._sync_unset()
|
||||
removed_items = []
|
||||
for key, row in list(self.key2row.items()):
|
||||
if row >= shape[0]:
|
||||
|
@ -310,8 +309,8 @@ cdef class Vectors:
|
|||
raise ValueError(Errors.E197.format(row=row, key=key))
|
||||
if vector is not None:
|
||||
self.data[row] = vector
|
||||
if self._unset.count(row):
|
||||
self._unset.erase(self._unset.find(row))
|
||||
if self._unset.count(row):
|
||||
self._unset.erase(self._unset.find(row))
|
||||
return row
|
||||
|
||||
def most_similar(self, queries, *, batch_size=1024, n=1, sort=True):
|
||||
|
@ -330,11 +329,14 @@ cdef class Vectors:
|
|||
RETURNS (tuple): The most similar entries as a `(keys, best_rows, scores)`
|
||||
tuple.
|
||||
"""
|
||||
filled = sorted(list({row for row in self.key2row.values()}))
|
||||
if len(filled) < n:
|
||||
raise ValueError(Errors.E198.format(n=n, n_rows=len(filled)))
|
||||
xp = get_array_module(self.data)
|
||||
|
||||
norms = xp.linalg.norm(self.data, axis=1, keepdims=True)
|
||||
norms = xp.linalg.norm(self.data[filled], axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
vectors = self.data / norms
|
||||
vectors = self.data[filled] / norms
|
||||
|
||||
best_rows = xp.zeros((queries.shape[0], n), dtype='i')
|
||||
scores = xp.zeros((queries.shape[0], n), dtype='f')
|
||||
|
@ -356,7 +358,8 @@ cdef class Vectors:
|
|||
scores[i:i+batch_size] = scores[sorted_index]
|
||||
best_rows[i:i+batch_size] = best_rows[sorted_index]
|
||||
|
||||
xp = get_array_module(self.data)
|
||||
for i, j in numpy.ndindex(best_rows.shape):
|
||||
best_rows[i, j] = filled[best_rows[i, j]]
|
||||
# Round values really close to 1 or -1
|
||||
scores = xp.around(scores, decimals=4, out=scores)
|
||||
# Account for numerical error we want to return in range -1, 1
|
||||
|
@ -419,6 +422,7 @@ cdef class Vectors:
|
|||
("vectors", load_vectors),
|
||||
))
|
||||
util.from_disk(path, serializers, [])
|
||||
self._sync_unset()
|
||||
return self
|
||||
|
||||
def to_bytes(self, **kwargs):
|
||||
|
@ -461,4 +465,9 @@ cdef class Vectors:
|
|||
("vectors", deserialize_weights)
|
||||
))
|
||||
util.from_bytes(data, deserializers, [])
|
||||
self._sync_unset()
|
||||
return self
|
||||
|
||||
def _sync_unset(self):
|
||||
filled = {row for row in self.key2row.values()}
|
||||
self._unset = cppset[int]({row for row in range(self.data.shape[0]) if row not in filled})
|
||||
|
|
Loading…
Reference in New Issue