Most similar bug (#4446)

* Add batch size indexing

* Don't sort if n == 1

* Add test for most similar vectors issue

* Change > to >=
This commit is contained in:
Daniel King 2019-10-16 14:18:55 -07:00 committed by Matthew Honnibal
parent 4a77d03ff7
commit e646956176
2 changed files with 16 additions and 3 deletions

View File

@ -50,6 +50,13 @@ def ngrams_vocab(en_vocab, ngrams_vectors):
def data():
return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f")
@pytest.fixture
def most_similar_vectors_data():
return numpy.asarray([[0.0, 1.0, 2.0],
[1.0, -2.0, 4.0],
[1.0, 1.0, -1.0],
[2.0, 3.0, 1.0]], dtype="f")
@pytest.fixture
def resize_data():
@ -127,6 +134,12 @@ def test_set_vector(strings, data):
assert list(v[strings[0]]) != list(orig[0])
def test_vectors_most_similar(most_similar_vectors_data):
v = Vectors(data=most_similar_vectors_data)
_, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True)
assert all(row[0] == i for i, row in enumerate(best_rows))
@pytest.mark.parametrize("text", ["apple and orange"])
def test_vectors_token_vector(tokenizer_v, vectors, text):
doc = tokenizer_v(text)
@ -284,7 +297,7 @@ def test_vocab_prune_vectors():
vocab.set_vector("dog", data[1])
vocab.set_vector("kitten", data[2])
remap = vocab.prune_vectors(2)
remap = vocab.prune_vectors(2, batch_size=2)
assert list(remap.keys()) == ["kitten"]
neighbour, similarity = list(remap.values())[0]
assert neighbour == "cat", remap

View File

@ -336,8 +336,8 @@ cdef class Vectors:
best_rows[i:i+batch_size] = xp.argpartition(sims, -n, axis=1)[:,-n:]
scores[i:i+batch_size] = xp.partition(sims, -n, axis=1)[:,-n:]
if sort:
sorted_index = xp.arange(scores.shape[0])[:,None],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
if sort and n >= 2:
sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
scores[i:i+batch_size] = scores[sorted_index]
best_rows[i:i+batch_size] = best_rows[sorted_index]