From ba10fd8af5442168eba704967b09c039ff2110f3 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 9 Jun 2015 01:39:54 +0200 Subject: [PATCH] * Add StateClass, to replace/refactor the mess in _state --- spacy/syntax/stateclass.pxd | 93 ++++++++++++++++++++++++++ spacy/syntax/stateclass.pyx | 126 ++++++++++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 spacy/syntax/stateclass.pxd create mode 100644 spacy/syntax/stateclass.pyx diff --git a/spacy/syntax/stateclass.pxd b/spacy/syntax/stateclass.pxd new file mode 100644 index 000000000..63e22cac5 --- /dev/null +++ b/spacy/syntax/stateclass.pxd @@ -0,0 +1,93 @@ +from libc.string cimport memcpy, memset + +from cymem.cymem cimport Pool + +from structs cimport TokenC + +from .syntax._state cimport State + +from .vocab cimport EMPTY_LEXEME + + +cdef TokenC EMPTY_TOKEN + + +cdef class StateClass: + cdef Pool mem + cdef int* _stack + cdef int* _buffer + cdef TokenC* _sent + cdef int length + cdef int _s_i + cdef int _b_i + + @staticmethod + cdef inline StateClass init(const TokenC* sent, int length): + cdef StateClass self = StateClass(length) + memcpy(self._sent, sent, sizeof(TokenC*) * length) + return self + + @staticmethod + cdef inline StateClass from_struct(Pool mem, const State* state): + cdef StateClass self = StateClass.init(state.sent, state.sent_len) + memcpy(self._stack, state.stack - state.stack_len, sizeof(int) * state.stack_len) + self._s_i = state.stack_len - 1 + self._b_i = state.i + return self + + cdef inline const TokenC* S_(self, int i) nogil: + return self.safe_get(self.S(i)) + + cdef inline const TokenC* B_(self, int i) nogil: + return self.safe_get(self.B(i)) + + cdef inline const TokenC* H_(self, int i) nogil: + return self.safe_get(self.B(i)) + + cdef inline const TokenC* L_(self, int i, int idx) nogil: + return self.safe_get(self.L(i, idx)) + + cdef inline const TokenC* R_(self, int i, int idx) nogil: + return self.safe_get(self.R(i, idx)) + + cdef inline const TokenC* safe_get(self, int i) nogil: + if 0 >= i >= self.length: + return &EMPTY_TOKEN + else: + return self._sent + + cdef int S(self, int i) nogil + cdef int B(self, int i) nogil + + cdef int H(self, int i) nogil + + cdef int L(self, int i, int idx) nogil + cdef int R(self, int i, int idx) nogil + + cdef bint empty(self) nogil + + cdef bint eol(self) nogil + + cdef bint is_final(self) nogil + + cdef bint has_head(self, int i) nogil + + cdef bint stack_is_connected(self) nogil + + cdef int stack_depth(self) nogil + + cdef int buffer_length(self) nogil + + cdef void push(self) nogil + + cdef void pop(self) nogil + + cdef void add_arc(self, int head, int child, int label) nogil + + cdef void del_arc(self, int head, int child) nogil + + cdef void set_sent_end(self, int i) nogil + + cdef void clone(self, StateClass src) nogil + + diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx new file mode 100644 index 000000000..511283da3 --- /dev/null +++ b/spacy/syntax/stateclass.pyx @@ -0,0 +1,126 @@ +from libc.string cimport memcpy, memset +from libc.stdint cimport uint32_t +from .vocab cimport EMPTY_LEXEME + + +memset(&EMPTY_TOKEN, 0, sizeof(TokenC)) +EMPTY_TOKEN.lex = &EMPTY_LEXEME + + +cdef class StateClass: + def __cinit__(self, int length): + self.mem = Pool() + self._stack = self.mem.alloc(sizeof(int), length) + self._buffer = self.mem.alloc(sizeof(int), length) + self._sent = self.mem.alloc(sizeof(TokenC*), length) + self.length = 0 + for i in range(self.length): + self._buffer[i] = i + + cdef int S(self, int i) nogil: + if self._s_i - (i+1) < 0: + return -1 + return self._stack[self._s_i - (i+1)] + + cdef int B(self, int i) nogil: + if (i + self._b_i) >= self.length: + return -1 + return self._buffer[self._b_i + i] + + cdef int H(self, int i) nogil: + if i < 0 or i >= self.length: + return -1 + return self._sent[i].head + i + + cdef int L(self, int i, int idx) nogil: + if 0 <= _popcount(self.safe_get(i).l_kids) <= idx: + return -1 + return _nth_significant_bit(self.safe_get(i).l_kids, idx) + + cdef int R(self, int i, int idx) nogil: + if 0 <= _popcount(self.safe_get(i).r_kids) <= idx: + return -1 + return _nth_significant_bit(self.safe_get(i).r_kids, idx) + + cdef bint empty(self) nogil: + return self._s_i <= 0 + + cdef bint eol(self) nogil: + return self._b_i >= self.length + + cdef bint is_final(self) nogil: + return self.eol() and self.empty() + + cdef bint has_head(self, int i) nogil: + return self.safe_get(i).head != 0 + + cdef bint stack_is_connected(self) nogil: + return False + + cdef int stack_depth(self) nogil: + return self._s_i + + cdef int buffer_length(self) nogil: + return self.length - self._b_i + + cdef void push(self) nogil: + self._stack[self._s_i] = self.B(0) + self._s_i += 1 + self._b_i += 1 + + cdef void pop(self) nogil: + self._s_i -= 1 + + cdef void add_arc(self, int head, int child, int label) nogil: + if self.has_head(child): + self.del_arc(self.H(child), child) + + cdef int dist = head - child + self._sent[child].head = dist + self._sent[child].dep = label + # Keep a bit-vector tracking child dependencies. If a word has a child at + # offset i from it, set that bit (tracking left and right separately) + if child > head: + self._sent[head].r_kids |= 1 << (-dist) + else: + self._sent[head].l_kids |= 1 << dist + + cdef void del_arc(self, int head, int child) nogil: + cdef int dist = head - child + if child > head: + self._sent[head].r_kids &= ~(1 << (-dist)) + else: + self._sent[head].l_kids &= ~(1 << dist) + + cdef void set_sent_end(self, int i) nogil: + if 0 < i < self.length: + self._sent[i].sent_end = True + + cdef void clone(self, StateClass src) nogil: + memcpy(self._sent, src._sent, self.length * sizeof(TokenC)) + memcpy(self._stack, src._stack, self.length * sizeof(int)) + memcpy(self._buffer, src._buffer, self.length * sizeof(int)) + self._b_i = src._b_i + self._s_i = src._s_i + + +# From https://en.wikipedia.org/wiki/Hamming_weight +cdef inline uint32_t _popcount(uint32_t x) nogil: + """Find number of non-zero bits.""" + cdef int count = 0 + while x != 0: + x &= x - 1 + count += 1 + return count + + +cdef inline uint32_t _nth_significant_bit(uint32_t bits, int n) nogil: + cdef int i + for i in range(32): + if bits & (1 << i): + if n < 1: + return i + n -= 1 + return 0 + +