From 03a520ec4f9296c3d0fa57045154db23f43d2e12 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 27 Oct 2016 17:58:56 +0200
Subject: [PATCH] Change signature of Parser.parseC, so that nr_class is read from the transition system. This allows the transition system to modify the number of actions in initialize_state.

---
 spacy/syntax/parser.pxd |  2 +-
 spacy/syntax/parser.pyx | 17 +++++++++--------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/spacy/syntax/parser.pxd b/spacy/syntax/parser.pxd
index 1ad0ce729..4370f5c6f 100644
--- a/spacy/syntax/parser.pxd
+++ b/spacy/syntax/parser.pxd
@@ -20,4 +20,4 @@ cdef class Parser:
     cdef readonly TransitionSystem moves
     cdef readonly object cfg
 
-    cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil
+    cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil
diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx
index 62b61c37b..d7fce5b3d 100644
--- a/spacy/syntax/parser.pyx
+++ b/spacy/syntax/parser.pyx
@@ -107,10 +107,9 @@ cdef class Parser:
         return (Parser, (self.vocab, self.moves, self.model), None, None)
 
     def __call__(self, Doc tokens):
-        cdef int nr_class = self.moves.n_moves
         cdef int nr_feat = self.model.nr_feat
         with nogil:
-            status = self.parseC(tokens.c, tokens.length, nr_feat, nr_class)
+            status = self.parseC(tokens.c, tokens.length, nr_feat)
         # Check for KeyboardInterrupt etc. Untested
         PyErr_CheckSignals()
         if status != 0:
@@ -123,7 +122,6 @@ cdef class Parser:
         cdef int* lengths = mem.alloc(batch_size, sizeof(int))
         cdef Doc doc
         cdef int i
-        cdef int nr_class = self.moves.n_moves
         cdef int nr_feat = self.model.nr_feat
         cdef int status
         queue = []
@@ -134,7 +132,7 @@ cdef class Parser:
             if len(queue) == batch_size:
                 with nogil:
                     for i in cython.parallel.prange(batch_size, num_threads=n_threads):
-                        status = self.parseC(doc_ptr[i], lengths[i], nr_feat, nr_class)
+                        status = self.parseC(doc_ptr[i], lengths[i], nr_feat)
                         if status != 0:
                             with gil:
                                 raise ParserStateError(queue[i])
@@ -146,7 +144,7 @@ cdef class Parser:
         batch_size = len(queue)
         with nogil:
             for i in cython.parallel.prange(batch_size, num_threads=n_threads):
-                status = self.parseC(doc_ptr[i], lengths[i], nr_feat, nr_class)
+                status = self.parseC(doc_ptr[i], lengths[i], nr_feat)
                 if status != 0:
                     with gil:
                         raise ParserStateError(queue[i])
@@ -155,7 +153,12 @@ cdef class Parser:
             self.moves.finalize_doc(doc)
             yield doc
 
-    cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil:
+    cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil:
+        state = new StateC(tokens, length)
+        # NB: This can change self.moves.n_moves!
+        self.moves.initialize_state(state)
+        nr_class = self.moves.n_moves
+
         cdef ExampleC eg
         eg.nr_feat = nr_feat
         eg.nr_atom = CONTEXT_SIZE
@@ -164,8 +167,6 @@ cdef class Parser:
         eg.atoms = calloc(sizeof(atom_t), CONTEXT_SIZE)
         eg.scores = calloc(sizeof(weight_t), nr_class)
         eg.is_valid = calloc(sizeof(int), nr_class)
-        state = new StateC(tokens, length)
-        self.moves.initialize_state(state)
         cdef int i
         while not state.is_final():
             self.model.set_featuresC(&eg, state)