mirror of https://github.com/explosion/spaCy.git
Support history features in stateclass
This commit is contained in:
parent
6aa6a5bc25
commit
ee41e4fea7
|
@ -1,4 +1,4 @@
|
|||
from libc.string cimport memcpy, memset
|
||||
from libc.string cimport memcpy, memset, memmove
|
||||
from libc.stdlib cimport malloc, calloc, free
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
|
||||
|
@ -15,6 +15,23 @@ from ..typedefs cimport attr_t
|
|||
cdef inline bint is_space_token(const TokenC* token) nogil:
|
||||
return Lexeme.c_check_flag(token.lex, IS_SPACE)
|
||||
|
||||
cdef struct RingBufferC:
|
||||
int[8] data
|
||||
int i
|
||||
int default
|
||||
|
||||
cdef inline int ring_push(RingBufferC* ring, int value) nogil:
|
||||
ring.data[ring.i] = value
|
||||
ring.i += 1
|
||||
if ring.i >= 8:
|
||||
ring.i = 0
|
||||
|
||||
cdef inline int ring_get(RingBufferC* ring, int i) nogil:
|
||||
if i >= ring.i:
|
||||
return ring.default
|
||||
else:
|
||||
return ring.data[ring.i-i]
|
||||
|
||||
|
||||
cdef cppclass StateC:
|
||||
int* _stack
|
||||
|
@ -23,6 +40,7 @@ cdef cppclass StateC:
|
|||
TokenC* _sent
|
||||
Entity* _ents
|
||||
TokenC _empty_token
|
||||
RingBufferC _hist
|
||||
int length
|
||||
int offset
|
||||
int _s_i
|
||||
|
@ -37,6 +55,7 @@ cdef cppclass StateC:
|
|||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
|
||||
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
|
||||
memset(&this._hist, 0, sizeof(this._hist))
|
||||
this.offset = 0
|
||||
cdef int i
|
||||
for i in range(length + (PADDING * 2)):
|
||||
|
@ -271,7 +290,14 @@ cdef cppclass StateC:
|
|||
sig[8] = this.B_(0)[0]
|
||||
sig[9] = this.E_(0)[0]
|
||||
sig[10] = this.E_(1)[0]
|
||||
return hash64(sig, sizeof(sig), this._s_i)
|
||||
return hash64(sig, sizeof(sig), this._s_i) \
|
||||
+ hash64(<void*>&this._hist, sizeof(RingBufferC), 1)
|
||||
|
||||
void push_hist(int act) nogil:
|
||||
ring_push(&this._hist, act)
|
||||
|
||||
int get_hist(int i) nogil:
|
||||
return ring_get(&this._hist, i)
|
||||
|
||||
void push() nogil:
|
||||
if this.B(0) != -1:
|
||||
|
|
|
@ -4,6 +4,7 @@ from __future__ import unicode_literals
|
|||
|
||||
from libc.string cimport memcpy, memset
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
import numpy
|
||||
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
from ..structs cimport Entity
|
||||
|
@ -38,6 +39,13 @@ cdef class StateClass:
|
|||
def token_vector_lenth(self):
|
||||
return self.doc.tensor.shape[1]
|
||||
|
||||
@property
|
||||
def history(self):
|
||||
hist = numpy.ndarray((8,), dtype='i')
|
||||
for i in range(8):
|
||||
hist[i] = self.c.get_hist(i+1)
|
||||
return hist
|
||||
|
||||
def is_final(self):
|
||||
return self.c.is_final()
|
||||
|
||||
|
|
Loading…
Reference in New Issue