From 9de98f5a6fc5d5c012c00fa9278dca21a6beb489 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 10 Aug 2015 00:08:46 +0200
Subject: [PATCH] * Add Parser.stepthrough method, with context manager

---
Review notes (not part of the commit; `git am` ignores the text between
the three-dash line and the first "diff --git" line):

* The subject says "Parser.stepthrough"; the method this diff actually
  adds is spelled `step_through`.
* Removed: Parser.get_state(tokens, initial_actions). It replayed a list
  of action names on a fresh StateClass ('_' meant "apply the model's
  prediction"), called tokens.set_parse(...) and returned the state.
* Added: Parser.step_through(doc), which returns StepwiseState(self, doc).
* Added: cdef class StepwiseState, with readonly attributes stcls, eg,
  doc and parser.
    . __init__ builds the StateClass from doc.data / doc.length, runs
      parser.moves.initialize_state on it and allocates the Example.
    . __enter__ returns the object itself. __exit__ ignores its three
      arguments and always calls finish(), including when the body of
      the with-statement raised.
    . is_final, stack, queue, heads and deps are read-only properties
      over the wrapped state; deps looks each token's `dep` value up in
      doc.vocab.strings.
    . predict() runs parser.predict on the current state and returns the
      name of the guessed move without applying it.
    . transition(action_name) looks the name up and applies it; '_' is
      first replaced by the result of predict().
    . finish() calls finalize_state only when the state is final, and
      calls doc.set_parse(stcls._sent) in either case.
* __exit__ names a parameter `type`, which shadows the builtin inside
  that method; harmless here because the argument is never used.

Usage, as far as this diff shows it:

    with parser.step_through(doc) as state:
        while not state.is_final:
            state.transition('_')

 spacy/syntax/parser.pyx | 80 +++++++++++++++++++++++++++++++----------
 1 file changed, 61 insertions(+), 19 deletions(-)

diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx
index bb7e2d96a..d53a1959a 100644
--- a/spacy/syntax/parser.pyx
+++ b/spacy/syntax/parser.pyx
@@ -88,25 +88,6 @@ cdef class Parser:
         self.parse(stcls, eg.c)
         tokens.set_parse(stcls._sent)
 
-    def get_state(self, Doc tokens, initial_actions):
-        cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
-        self.moves.initialize_state(stcls)
-        cdef object action_name
-        cdef Transition action
-        cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
-                                  self.model.n_feats, self.model.n_feats)
-        for action_name in initial_actions:
-            if action_name == '_':
-                self.predict(stcls, &eg.c)
-                action = self.moves.c[eg.c.guess]
-            else:
-                action = self.moves.lookup_transition(action_name)
-            action.do(stcls, action.label)
-        if stcls.is_final():
-            self.moves.finalize_state(stcls)
-        tokens.set_parse(stcls._sent)
-        return stcls
-
     cdef void predict(self, StateClass stcls, ExampleC* eg) nogil:
         memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
         self.moves.set_valid(eg.is_valid, stcls)
@@ -139,3 +120,64 @@ cdef class Parser:
             self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
             loss += eg.c.loss
         return loss
+
+    def step_through(self, Doc doc):
+        return StepwiseState(self, doc)
+
+
+cdef class StepwiseState:
+    cdef readonly StateClass stcls
+    cdef readonly Example eg
+    cdef readonly Doc doc
+    cdef readonly Parser parser
+
+    def __init__(self, Parser parser, Doc doc):
+        self.parser = parser
+        self.doc = doc
+        self.stcls = StateClass.init(doc.data, doc.length)
+        self.parser.moves.initialize_state(self.stcls)
+        self.eg = Example(self.parser.model.n_classes, CONTEXT_SIZE,
+                          self.parser.model.n_feats, self.parser.model.n_feats)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.finish()
+
+    @property
+    def is_final(self):
+        return self.stcls.is_final()
+
+    @property
+    def stack(self):
+        return self.stcls.stack
+
+    @property
+    def queue(self):
+        return self.stcls.queue
+
+    @property
+    def heads(self):
+        return [self.stcls.H(i) for i in range(self.stcls.length)]
+
+    @property
+    def deps(self):
+        return [self.doc.vocab.strings[self.stcls._sent[i].dep]
+                for i in range(self.stcls.length)]
+
+    def predict(self):
+        self.parser.predict(self.stcls, &self.eg.c)
+        action = self.parser.moves.c[self.eg.c.guess]
+        return self.parser.moves.move_name(action.move, action.label)
+
+    def transition(self, action_name):
+        if action_name == '_':
+            action_name = self.predict()
+        action = self.parser.moves.lookup_transition(action_name)
+        action.do(self.stcls, action.label)
+
+    def finish(self):
+        if self.stcls.is_final():
+            self.parser.moves.finalize_state(self.stcls)
+        self.doc.set_parse(self.stcls._sent)