Fix test for vectors

This commit is contained in:
Matthew Honnibal 2017-08-19 22:09:12 +02:00
parent b8e1603cc4
commit 41c2218c53
1 changed files with 6 additions and 0 deletions

View File

@ -40,11 +40,15 @@ def test_init_vectors_with_data(strings, data):
def test_init_vectors_with_width(strings): def test_init_vectors_with_width(strings):
v = Vectors(strings, 3) v = Vectors(strings, 3)
for string in strings:
v.add(string)
assert v.shape == (len(strings), 3) assert v.shape == (len(strings), 3)
def test_get_vector(strings, data): def test_get_vector(strings, data):
v = Vectors(strings, data) v = Vectors(strings, data)
for string in strings:
v.add(string)
assert list(v[strings[0]]) == list(data[0]) assert list(v[strings[0]]) == list(data[0])
assert list(v[strings[0]]) != list(data[1]) assert list(v[strings[0]]) != list(data[1])
assert list(v[strings[1]]) != list(data[0]) assert list(v[strings[1]]) != list(data[0])
@ -53,6 +57,8 @@ def test_get_vector(strings, data):
def test_set_vector(strings, data): def test_set_vector(strings, data):
orig = data.copy() orig = data.copy()
v = Vectors(strings, data) v = Vectors(strings, data)
for string in strings:
v.add(string)
assert list(v[strings[0]]) == list(orig[0]) assert list(v[strings[0]]) == list(orig[0])
assert list(v[strings[0]]) != list(orig[1]) assert list(v[strings[0]]) != list(orig[1])
v[strings[0]] = data[1] v[strings[0]] = data[1]