Support history features in stateclass

This commit is contained in:
Matthew Honnibal 2017-10-03 12:43:48 +02:00
parent 6aa6a5bc25
commit ee41e4fea7
2 changed files with 36 additions and 2 deletions

View File

@ -1,4 +1,4 @@
from libc.string cimport memcpy, memset
from libc.string cimport memcpy, memset, memmove
from libc.stdlib cimport malloc, calloc, free
from libc.stdint cimport uint32_t, uint64_t
@ -15,6 +15,23 @@ from ..typedefs cimport attr_t
cdef inline bint is_space_token(const TokenC* token) nogil:
    # Report whether the token's lexeme has the IS_SPACE flag set, as
    # answered by Lexeme.c_check_flag.
    cdef bint has_space_flag = Lexeme.c_check_flag(token.lex, IS_SPACE)
    return has_space_flag
cdef struct RingBufferC:
    # Circular store holding up to the 8 most recently pushed values.
    int[8] data
    # Index of the slot the next ring_push() writes; wraps from 7 back to 0.
    int i
    # Value ring_get() returns when the requested entry does not exist.
    int default
    # Number of values pushed so far, saturating at 8. A zero-filled struct
    # is therefore a valid empty buffer.
    int n


cdef inline int ring_push(RingBufferC* ring, int value) nogil:
    # Store `value` as the newest entry. Once 8 values are held, each push
    # overwrites the oldest one. Always returns 0.
    ring.data[ring.i] = value
    ring.i += 1
    if ring.i >= 8:
        ring.i = 0
    # Track how many slots are valid so that wrapping `ring.i` back to 0 does
    # not make the stored history look empty.
    if ring.n < 8:
        ring.n += 1
    return 0


cdef inline int ring_get(RingBufferC* ring, int i) nogil:
    # Return the value pushed `i` steps ago: i == 1 is the most recent push,
    # i == ring.n the oldest value still retained. Any other `i` (zero,
    # negative, or further back than what has been pushed) yields
    # `ring.default` instead of reading an unwritten or out-of-range slot.
    if i < 1 or i > ring.n:
        return ring.default
    # ring.i is in [0, 8) and i in [1, 8], so adding 8 keeps the left operand
    # non-negative before the wrap-around modulo.
    return ring.data[(ring.i - i + 8) % 8]
cdef cppclass StateC:
int* _stack
@ -23,6 +40,7 @@ cdef cppclass StateC:
TokenC* _sent
Entity* _ents
TokenC _empty_token
RingBufferC _hist
int length
int offset
int _s_i
@ -37,6 +55,7 @@ cdef cppclass StateC:
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
memset(&this._hist, 0, sizeof(this._hist))
this.offset = 0
cdef int i
for i in range(length + (PADDING * 2)):
@ -271,7 +290,14 @@ cdef cppclass StateC:
sig[8] = this.B_(0)[0]
sig[9] = this.E_(0)[0]
sig[10] = this.E_(1)[0]
return hash64(sig, sizeof(sig), this._s_i)
return hash64(sig, sizeof(sig), this._s_i) \
+ hash64(<void*>&this._hist, sizeof(RingBufferC), 1)
void push_hist(int act) nogil:
    # Append `act` to this state's history ring buffer via ring_push().
    cdef RingBufferC* history = &this._hist
    ring_push(history, act)
int get_hist(int i) nogil:
    # Look up entry `i` of this state's history ring buffer via ring_get().
    cdef int entry = ring_get(&this._hist, i)
    return entry
void push() nogil:
if this.B(0) != -1:

View File

@ -4,6 +4,7 @@ from __future__ import unicode_literals
from libc.string cimport memcpy, memset
from libc.stdint cimport uint32_t, uint64_t
import numpy
from ..vocab cimport EMPTY_LEXEME
from ..structs cimport Entity
@ -38,6 +39,13 @@ cdef class StateClass:
def token_vector_lenth(self):
    """Return the size of the second axis of ``self.doc.tensor``."""
    tensor_shape = self.doc.tensor.shape
    return tensor_shape[1]
@property
def history(self):
    """Return the state's history as an 8-element NumPy array of C ints.

    Element ``k`` of the result is ``self.c.get_hist(k + 1)`` for ``k`` in
    0..7.
    """
    # numpy.array() is the documented way to build an array from values; the
    # low-level numpy.ndarray() constructor returns uninitialised memory that
    # then has to be filled slot by slot.
    return numpy.array([self.c.get_hist(offset) for offset in range(1, 9)],
                       dtype='i')
def is_final(self):
    """Return whatever the underlying C state's ``is_final()`` reports."""
    finished = self.c.is_final()
    return finished