* Update tests for python3

This commit is contained in:
Matthew Honnibal 2015-07-24 03:47:59 +02:00
parent 1ab25e4dad
commit ce984f471c
3 changed files with 5 additions and 5 deletions

View File

@ -45,7 +45,7 @@ def test1():
codec = HuffmanCodec(list(enumerate(probs))) codec = HuffmanCodec(list(enumerate(probs)))
py_codes = py_encode(dict(enumerate(probs))) py_codes = py_encode(dict(enumerate(probs)))
py_codes = py_codes.items() py_codes = list(py_codes.items())
py_codes.sort() py_codes.sort()
assert codec.strings == [c for i, c in py_codes] assert codec.strings == [c for i, c in py_codes]
@ -60,7 +60,7 @@ def test_round_trip():
strings = list(codec.strings) strings = list(codec.strings)
codes = {codec.leaves[i]: strings[i] for i in range(len(codec.leaves))} codes = {codec.leaves[i]: strings[i] for i in range(len(codec.leaves))}
bits = codec.encode(message) bits = codec.encode(message)
string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in bits.as_bytes()) string = ''.join('{0:b}'.format(c).rjust(8, '0')[::-1] for c in bits.as_bytes())
for word in message: for word in message:
code = codes[word] code = codes[word]
assert string[:len(code)] == code assert string[:len(code)] == code
@ -76,7 +76,7 @@ def test_rosetta():
symb2freq = defaultdict(int) symb2freq = defaultdict(int)
for ch in txt: for ch in txt:
symb2freq[ch] += 1 symb2freq[ch] += 1
by_freq = symb2freq.items() by_freq = list(symb2freq.items())
by_freq.sort(reverse=True, key=lambda item: item[1]) by_freq.sort(reverse=True, key=lambda item: item[1])
symbols = [sym for sym, prob in by_freq] symbols = [sym for sym, prob in by_freq]

View File

@ -61,7 +61,7 @@ def test_char_packer(vocab):
bits.seek(0) bits.seek(0)
result = [b''] * len(byte_str) result = [b''] * len(byte_str)
packer.char_codec.decode(bits, result) packer.char_codec.decode(bits, result)
assert b''.join(result) == byte_str assert bytearray(result) == byte_str
def test_packer_unannotated(tokenizer): def test_packer_unannotated(tokenizer):

View File

@ -39,7 +39,7 @@ def test_retrieve_id(sstore):
def test_med_string(sstore): def test_med_string(sstore):
nine_char_string = sstore[b'0123456789'] nine_char_string = sstore[b'0123456789']
assert sstore[nine_char_string] == b'0123456789' assert sstore[nine_char_string] == u'0123456789'
dummy = sstore[b'A'] dummy = sstore[b'A']
assert sstore[b'0123456789'] == nine_char_string assert sstore[b'0123456789'] == nine_char_string