2017-01-12 14:09:49 +00:00
|
|
|
# coding: utf-8
|
|
|
|
from __future__ import unicode_literals
|
2016-10-21 15:07:21 +00:00
|
|
|
|
2017-01-12 14:09:49 +00:00
|
|
|
import numpy
|
2017-10-31 10:40:46 +00:00
|
|
|
from numpy.testing import assert_allclose
|
2017-10-31 01:00:26 +00:00
|
|
|
from ...vocab import Vocab
|
|
|
|
from ..._ml import cosine
|
2016-10-21 15:07:21 +00:00
|
|
|
|
|
|
|
|
2017-10-31 01:00:26 +00:00
|
|
|
def test_vocab_add_vector():
    """Vectors stored with ``Vocab.set_vector`` read back unchanged from the lexeme."""
    vocab = Vocab()
    # Backing table for the vectors; only the first two rows are filled and used.
    table = numpy.ndarray((5, 3), dtype='f')
    expected = [(u'cat', 1.), (u'dog', 2.)]
    for row, (_word, fill) in enumerate(expected):
        table[row] = fill
    for row, (word, _fill) in enumerate(expected):
        vocab.set_vector(word, table[row])
    for word, fill in expected:
        lexeme = vocab[word]
        # Each stored vector is constant along its 3 dimensions.
        assert list(lexeme.vector) == [fill] * 3
|
|
|
|
|
|
|
|
|
|
|
|
def test_vocab_prune_vectors():
    """Pruning to two vectors remaps the dropped word to its nearest kept neighbour.

    Three words get vectors: 'cat' (all 1.0), 'dog' (all 2.0) and
    'kitten' (all 1.1). After ``prune_vectors(2)`` the returned mapping is
    expected to contain only 'kitten', mapped to a ``(neighbour, similarity)``
    pair whose neighbour is 'cat' and whose similarity matches the cosine
    of the two original vectors.
    """
    vocab = Vocab()
    # Look the words up so they exist in the vocab before vectors are attached.
    _ = vocab[u'cat']
    _ = vocab[u'dog']
    _ = vocab[u'kitten']
    # Zero-initialised (numpy.ndarray would leave rows 3-4 as arbitrary memory).
    data = numpy.zeros((5, 3), dtype='f')
    data[0] = 1.
    data[1] = 2.
    data[2] = 1.1
    vocab.set_vector(u'cat', data[0])
    vocab.set_vector(u'dog', data[1])
    vocab.set_vector(u'kitten', data[2])

    remap = vocab.prune_vectors(2)
    assert list(remap.keys()) == [u'kitten']
    # dict.values() is a non-subscriptable view on Python 3; index by the key
    # already asserted above instead of remap.values()[0].
    neighbour, similarity = remap[u'kitten']
    assert neighbour == u'cat', remap
    assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-6)
|