* Work on serialization. Needs more reorganisation

This commit is contained in:
Matthew Honnibal 2015-07-16 19:55:47 +02:00
parent d8458d6a25
commit c8282f9934
2 changed files with 17 additions and 9 deletions

View File

@@ -4,6 +4,8 @@ from libc.stdint cimport int64_t
 from libc.stdint cimport int32_t
 from libc.stdint cimport uint64_t
+from .vocab cimport Vocab
 cdef struct Node:
 float prob
@@ -18,6 +20,7 @@ cdef struct Code:
 cdef class Serializer:
 cdef list codecs
+cdef Vocab vocab
 cdef class HuffmanCodec:

View File

@@ -3,10 +3,15 @@ from libc.stdint cimport uint32_t
 from libc.stdint cimport int64_t
 from libc.stdint cimport int32_t
 from libc.stdint cimport uint64_t
+from libcpp.queue cimport priority_queue
+from libcpp.pair cimport pair
 from preshed.maps cimport PreshMap
 from murmurhash.mrmr cimport hash64
+from .tokens.doc cimport Doc
+from .vocab cimport Vocab
+from os import path
 import numpy
 cimport cython
@@ -97,7 +102,7 @@ cdef class Serializer:
 def __init__(self, Vocab vocab, data_dir):
 model_dir = path.join(data_dir, 'bitter')
 self.vocab = vocab # Vocab owns the word codec, the big one
-self.cfg = Config.read(model_dir, 'config')
+#self.cfg = Config.read(model_dir, 'config')
 self.codecs = tuple([CodecWrapper(attr) for attr in self.cfg.attrs])
 def __call__(self, doc_or_bits):
@@ -129,7 +134,7 @@ cdef class Serializer:
 cdef bint is_spacy
 for id_ in ids:
 is_spacy = biterator.next()
-doc.push_back(vocab.lexemes.at(id_), is_spacy)
+doc.push_back(self.vocab.lexemes.at(id_), is_spacy)
 cdef int length = doc.length
 array = numpy.zeros(shape=(length, len(self.codecs)), dtype=numpy.int)
@@ -139,20 +144,20 @@ cdef class Serializer:
 return doc
-cdef class AttributeEncoder:
+cdef class CodecWrapper:
 """Wrapper around HuffmanCodec"""
 def __init__(self, freqs, id=0):
 cdef uint64_t key
 cdef uint64_t count
-cdef pair[uint64_t] item
-cdef priority_queue[pair[uint64_t]] items
+cdef pair[uint64_t, uint64_t] item
+cdef priority_queue[pair[uint64_t, uint64_t]] items
 for key, count in freqs:
 item.first = count
 item.second = key
 items.push(item)
-weights = array('f')
-keys = array('i')
+weights = [] #array('f')
+keys = [] #array('i')
 key_to_i = PreshMap()
 i = 0
 while not items.empty():
@@ -188,8 +193,8 @@ cdef class HuffmanCodec:
 eol (uint32_t): The index of the weight of the EOL symbol.
 """
-def __init__(self, float[:] weights, unt32_t eol):
-self.codes.resize(len(probs))
+def __init__(self, float[:] weights, uint32_t eol):
+self.codes.resize(len(weights))
 for i in range(len(self.codes)):
 self.codes[i].bits = 0
 self.codes[i].length = 0