diff --git a/spacy/tests/vocab_vectors/test_vectors.py b/spacy/tests/vocab_vectors/test_vectors.py index 4b2e171a6..0b0fd89dc 100644 --- a/spacy/tests/vocab_vectors/test_vectors.py +++ b/spacy/tests/vocab_vectors/test_vectors.py @@ -50,6 +50,13 @@ def ngrams_vocab(en_vocab, ngrams_vectors): def data(): return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f") +@pytest.fixture +def most_similar_vectors_data(): + return numpy.asarray([[0.0, 1.0, 2.0], + [1.0, -2.0, 4.0], + [1.0, 1.0, -1.0], + [2.0, 3.0, 1.0]], dtype="f") + @pytest.fixture def resize_data(): @@ -127,6 +134,12 @@ def test_set_vector(strings, data): assert list(v[strings[0]]) != list(orig[0]) +def test_vectors_most_similar(most_similar_vectors_data): + v = Vectors(data=most_similar_vectors_data) + _, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True) + assert all(row[0] == i for i, row in enumerate(best_rows)) + + @pytest.mark.parametrize("text", ["apple and orange"]) def test_vectors_token_vector(tokenizer_v, vectors, text): doc = tokenizer_v(text) @@ -284,7 +297,7 @@ def test_vocab_prune_vectors(): vocab.set_vector("dog", data[1]) vocab.set_vector("kitten", data[2]) - remap = vocab.prune_vectors(2) + remap = vocab.prune_vectors(2, batch_size=2) assert list(remap.keys()) == ["kitten"] neighbour, similarity = list(remap.values())[0] assert neighbour == "cat", remap diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index 881f01052..6ad1202de 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -336,8 +336,8 @@ cdef class Vectors: best_rows[i:i+batch_size] = xp.argpartition(sims, -n, axis=1)[:,-n:] scores[i:i+batch_size] = xp.partition(sims, -n, axis=1)[:,-n:] - if sort: - sorted_index = xp.arange(scores.shape[0])[:,None],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1] + if sort and n >= 2: + sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1] scores[i:i+batch_size] = scores[sorted_index] best_rows[i:i+batch_size] = best_rows[sorted_index]