Modernise Huffman tests

This commit is contained in:
Ines Montani 2017-01-12 21:58:40 +01:00
parent edeeeccea5
commit 5dbc6e59f6
1 changed file with 23 additions and 24 deletions

View File

@@ -1,15 +1,15 @@
# coding: utf-8
from __future__ import unicode_literals
from __future__ import division
import pytest
from ...serialize.huffman import HuffmanCodec
from ...serialize.bits import BitArray
from spacy.serialize.huffman import HuffmanCodec
from spacy.serialize.bits import BitArray
import numpy
import math
from heapq import heappush, heappop, heapify
from collections import defaultdict
import numpy
import pytest
def py_encode(symb2freq):
@@ -29,7 +29,7 @@ def py_encode(symb2freq):
return dict(heappop(heap)[1:])
def test1():
def test_serialize_huffman_1():
probs = numpy.zeros(shape=(10,), dtype=numpy.float32)
probs[0] = 0.3
probs[1] = 0.2
@@ -41,45 +41,44 @@ def test1():
probs[7] = 0.005
probs[8] = 0.0001
probs[9] = 0.000001
codec = HuffmanCodec(list(enumerate(probs)))
py_codes = py_encode(dict(enumerate(probs)))
py_codes = list(py_codes.items())
assert codec.strings == [c for i, c in py_codes]
def test_serialize_huffman_empty():
    """A codec built from an empty frequency table has no code strings."""
    codec = HuffmanCodec({})
    assert codec.strings == []
def test_serialize_huffman_round_trip():
    """Encode a word sequence, verify the emitted bits, then decode it back.

    The test checks two things:
    1. The bit stream produced by ``codec.encode`` is exactly the
       concatenation of the per-word code strings exposed by the codec.
    2. ``codec.decode`` restores the original word sequence.
    """
    words = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'the',
             'lazy', 'dog', '.']
    freqs = {'the': 10, 'quick': 3, 'brown': 4, 'fox': 1, 'jumped': 5,
             'over': 8, 'lazy': 1, 'dog': 2, '.': 9}
    codec = HuffmanCodec(freqs.items())

    # Map every leaf symbol to its code string; leaves and strings are
    # paired by position.
    strings = list(codec.strings)
    codes = {leaf: strings[i] for i, leaf in enumerate(codec.leaves)}

    bits = codec.encode(words)
    # Render each byte as eight binary digits and reverse them before joining.
    # NOTE(review): presumably the bit array is least-significant-bit first
    # within a byte -- confirm against BitArray.
    string = ''.join('{0:b}'.format(c).rjust(8, '0')[::-1] for c in bits.as_bytes())

    # Consume the bit string one word at a time: each word's code must be the
    # next prefix of what remains.
    for word in words:
        code = codes[word]
        assert string[:len(code)] == code
        string = string[len(code):]

    # decode() fills a pre-sized output list in place.
    unpacked = [0] * len(words)
    codec.decode(bits, unpacked)
    assert words == unpacked
def test_rosetta():
txt = u"this is an example for huffman encoding"
def test_serialize_huffman_rosetta():
text = "this is an example for huffman encoding"
symb2freq = defaultdict(int)
for ch in txt:
for ch in text:
symb2freq[ch] += 1
by_freq = list(symb2freq.items())
by_freq.sort(reverse=True, key=lambda item: item[1])
@@ -101,7 +100,7 @@ def test_rosetta():
assert my_exp_len == py_exp_len
def test_vocab(EN):
codec = HuffmanCodec([(w.orth, numpy.exp(w.prob)) for w in EN.vocab])
expected_length = 0