* Implement character-based codec, so that we can do word/char backoff

This commit is contained in:
Matthew Honnibal 2015-07-19 22:03:39 +02:00
parent cd1d047cb8
commit ae78c9e3ce
1 changed files with 68 additions and 53 deletions

View File

@ -1,12 +1,18 @@
# cython: profile=True # cython: profile=True
from __future__ import unicode_literals
from libc.stdint cimport uint32_t from libc.stdint cimport uint32_t
from libc.stdint cimport uint64_t from libc.stdint cimport uint64_t
from libc.math cimport exp as c_exp from libc.math cimport exp as c_exp
from libcpp.queue cimport priority_queue from libcpp.queue cimport priority_queue
from libcpp.pair cimport pair from libcpp.pair cimport pair
from ..structs cimport UniStr
from ..strings cimport slice_unicode
from cymem.cymem cimport Address, Pool from cymem.cymem cimport Address, Pool
from preshed.maps cimport PreshMap from preshed.maps cimport PreshMap
from preshed.counter cimport PreshCounter
from ..attrs cimport ORTH, ID, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE from ..attrs cimport ORTH, ID, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
@ -51,51 +57,6 @@ cdef class _BinaryCodec:
break break
cdef class _AttributeCodec:
    """Huffman codec over attribute values, addressed by frequency rank.

    Attribute values are first mapped to a dense index (their position in
    the frequency-sorted key table), and that index is what the wrapped
    HuffmanCodec encodes and decodes.
    """
    cdef Pool mem
    cdef attr_t* _keys
    cdef dict _map
    cdef HuffmanCodec _codec

    def __init__(self, freqs):
        """Build the key table and the underlying Huffman codec.

        freqs: iterable of (attribute value, frequency) pairs.
        """
        self.mem = Pool()
        cdef attr_t key
        cdef float freq
        cdef pair[float, attr_t] entry
        cdef priority_queue[pair[float, attr_t]] heap

        # Frequency goes in the first slot so the queue orders by it.
        for key, freq in freqs:
            entry.first = freq
            entry.second = key
            heap.push(entry)

        weights = numpy.ndarray(shape=(heap.size(),), dtype=numpy.float32)
        self._keys = <attr_t*>self.mem.alloc(heap.size(), sizeof(attr_t))
        self._map = {}

        # Drain the queue: each popped entry gets the next index, recorded in
        # both directions (index -> key in _keys, key -> index in _map).
        cdef int rank = 0
        while not heap.empty():
            entry = heap.top()
            heap.pop()
            self._keys[rank] = entry.second
            weights[rank] = entry.first
            self._map[self._keys[rank]] = rank
            rank += 1

        self._codec = HuffmanCodec(weights)

    def encode(self, attr_t[:] msg, BitArray dest):
        """Encode the attribute values in msg into dest.

        Note that msg is overwritten in place with the index of each value
        before being handed to the Huffman codec.
        """
        cdef int pos
        for pos in range(len(msg)):
            msg[pos] = self._map[msg[pos]]
        self._codec.encode(msg, dest)

    def decode(self, BitArray bits, attr_t[:] dest):
        """Decode len(dest) values from bits into dest.

        The Huffman codec fills dest with indices, which are then replaced
        by the attribute values stored at those indices in the key table.
        """
        cdef int pos
        self._codec.decode(bits, dest)
        for pos in range(len(dest)):
            dest[pos] = <attr_t>self._keys[dest[pos]]
def _gen_orths(Vocab vocab): def _gen_orths(Vocab vocab):
cdef attr_t orth cdef attr_t orth
cdef size_t addr cdef size_t addr
@ -104,17 +65,34 @@ def _gen_orths(Vocab vocab):
yield orth, c_exp(lex.prob) yield orth, c_exp(lex.prob)
def _gen_chars(Vocab vocab):
    """Accumulate a weight for every character that occurs in the vocabulary.

    Every lexeme in ``vocab._by_orth`` adds ``exp(lex.prob)`` to each
    character of its string, and the same amount once to the space
    character.

    Returns the ``(character, weight)`` items of the accumulated table; the
    Packer constructor passes them to ``HuffmanCodec`` to build the
    character codec.
    """
    cdef attr_t orth
    cdef size_t addr
    # The space entry must use a unicode key, like every other key in the
    # table. A bytes key (b' ') is a different dict key from u' ' on
    # Python 3, which made the `char_weights[u' '] +=` update below fail
    # with KeyError.
    char_weights = {u' ': 0.0}
    cdef unicode string
    cdef unicode char
    for orth, addr in vocab._by_orth.items():
        lex = <LexemeC*>addr
        string = vocab.strings[lex.orth]
        # The weight is constant for the lexeme, so compute it once.
        weight = c_exp(lex.prob)
        for char in string:
            char_weights[char] = char_weights.get(char, 0.0) + weight
        # One separating space per lexeme.
        # NOTE(review): placement outside the character loop is inferred;
        # the original indentation was not available - confirm.
        char_weights[u' '] += weight
    return char_weights.items()
cdef class Packer: cdef class Packer:
def __init__(self, Vocab vocab, attr_freqs): def __init__(self, Vocab vocab, attr_freqs):
self.vocab = vocab self.vocab = vocab
self.lex_codec = _AttributeCodec(_gen_orths(vocab)) self.orth_codec = HuffmanCodec(_gen_orths(vocab))
self.char_codec = HuffmanCodec(_gen_chars(vocab))
codecs = [_AttributeCodec(_gen_orths(vocab)), _BinaryCodec()] codecs = []
attrs = [ORTH, SPACY] attrs = []
for attr, freqs in sorted(attr_freqs): for attr, freqs in sorted(attr_freqs):
if attr in (ORTH, ID, SPACY): if attr in (ORTH, ID, SPACY):
continue continue
codecs.append(_AttributeCodec(freqs)) codecs.append(HuffmanCodec(freqs))
attrs.append(attr) attrs.append(attr)
self._codecs = tuple(codecs) self._codecs = tuple(codecs)
self.attrs = tuple(attrs) self.attrs = tuple(attrs)
@ -124,10 +102,11 @@ cdef class Packer:
return cls(vocab, util.read_encoding_freqs(data_dir)) return cls(vocab, util.read_encoding_freqs(data_dir))
def pack(self, Doc doc): def pack(self, Doc doc):
array = doc.to_array(self.attrs)
cdef BitArray bits = BitArray() cdef BitArray bits = BitArray()
cdef uint32_t length = len(doc) cdef uint32_t length = len(doc.string)
bits.extend(length, 32) bits.extend(length, 32)
self._char_encode(doc, bits)
array = doc.to_array(self.attrs)
for i, codec in enumerate(self._codecs): for i, codec in enumerate(self._codecs):
codec.encode(array[:, i], bits) codec.encode(array[:, i], bits)
return bits return bits
@ -135,7 +114,43 @@ cdef class Packer:
def unpack(self, BitArray bits): def unpack(self, BitArray bits):
bits.seek(0) bits.seek(0)
cdef uint32_t length = bits.read32() cdef uint32_t length = bits.read32()
array = numpy.zeros(shape=(length, len(self._codecs)), dtype=numpy.int32) doc = self._char_decode(bits, length)
array = numpy.zeros(shape=(len(doc), len(self._codecs)), dtype=numpy.int32)
for i, codec in enumerate(self._codecs): for i, codec in enumerate(self._codecs):
codec.decode(bits, array[:, i]) codec.decode(bits, array[:, i])
return array
doc.from_array(self.attrs, array)
return doc
def _char_encode(self, Doc doc, BitArray bits):
    """Write the document text and its token boundaries to bits.

    The whole of ``doc.string`` is encoded with ``self.char_codec`` first.
    It is followed by one boundary bit per character: False for every
    character of a token except its last, True for the last character, and
    an extra False when the token has trailing whitespace.
    """
    cdef unicode text = doc.string
    self.char_codec.encode(text, bits)
    for tok in doc:
        n_inner = len(tok) - 1
        # Characters before the final one are not token ends.
        for _ in range(n_inner):
            bits.append(False)
        # Final character of the token.
        bits.append(True)
        # The whitespace character after the token, if any.
        if tok.whitespace_:
            bits.append(False)
def _char_decode(self, BitArray bits, n):
    """Rebuild a Doc from the encoded characters and token-boundary bits.

    Counterpart of ``_char_encode``: ``n`` characters are decoded with
    ``self.char_codec`` and joined into the document text, then one
    boundary bit is read per character. A True bit marks the last character
    of a token; the token is looked up in the vocabulary and appended to
    the Doc.

    bits: the BitArray to read from.
    n: number of characters to decode.

    Returns the reconstructed Doc.
    """
    chars = [u''] * n
    self.char_codec.decode(bits, chars)
    cdef unicode string = u''.join(chars)
    cdef Doc tokens = Doc(self.vocab)
    cdef int i
    cdef int start = 0
    cdef bint is_spacy
    cdef UniStr span
    cdef int length = len(string)
    # NOTE(review): assumes iterating `bits` continues from the current read
    # position, i.e. just after the character payload - confirm against
    # BitArray's iterator.
    iter_bits = iter(bits)
    for i in range(length):
        # Use the builtin next() rather than the Python 2-only .next()
        # method, so this works on both Python 2 and Python 3.
        is_end_token = next(iter_bits)
        if is_end_token:
            slice_unicode(&span, string, start, i+1)
            lex = self.vocab.get(tokens.mem, &span)
            # The token owns a trailing space if the next character is one.
            is_spacy = (i+1) < length and string[i+1] == u' '
            tokens.push_back(lex, is_spacy)
            # The next token starts after this one and after its trailing
            # space, when present.
            start = i + 1 + is_spacy
    return tokens