* Add Parser.stepthrough method, with context manager

This commit is contained in:
Matthew Honnibal 2015-08-10 00:08:46 +02:00
parent fe43f8cf39
commit 9de98f5a6f
1 changed files with 61 additions and 19 deletions

View File

@ -88,25 +88,6 @@ cdef class Parser:
self.parse(stcls, eg.c) self.parse(stcls, eg.c)
tokens.set_parse(stcls._sent) tokens.set_parse(stcls._sent)
def get_state(self, Doc tokens, initial_actions):
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(stcls)
cdef object action_name
cdef Transition action
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
self.model.n_feats, self.model.n_feats)
for action_name in initial_actions:
if action_name == '_':
self.predict(stcls, &eg.c)
action = self.moves.c[eg.c.guess]
else:
action = self.moves.lookup_transition(action_name)
action.do(stcls, action.label)
if stcls.is_final():
self.moves.finalize_state(stcls)
tokens.set_parse(stcls._sent)
return stcls
cdef void predict(self, StateClass stcls, ExampleC* eg) nogil: cdef void predict(self, StateClass stcls, ExampleC* eg) nogil:
memset(eg.scores, 0, eg.nr_class * sizeof(weight_t)) memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
self.moves.set_valid(eg.is_valid, stcls) self.moves.set_valid(eg.is_valid, stcls)
@ -139,3 +120,64 @@ cdef class Parser:
self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label) self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
loss += eg.c.loss loss += eg.c.loss
return loss return loss
def step_through(self, Doc doc):
return StepwiseState(self, doc)
cdef class StepwiseState:
cdef readonly StateClass stcls
cdef readonly Example eg
cdef readonly Doc doc
cdef readonly Parser parser
def __init__(self, Parser parser, Doc doc):
self.parser = parser
self.doc = doc
self.stcls = StateClass.init(doc.data, doc.length)
self.parser.moves.initialize_state(self.stcls)
self.eg = Example(self.parser.model.n_classes, CONTEXT_SIZE,
self.parser.model.n_feats, self.parser.model.n_feats)
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.finish()
@property
def is_final(self):
return self.stcls.is_final()
@property
def stack(self):
return self.stcls.stack
@property
def queue(self):
return self.stcls.queue
@property
def heads(self):
return [self.stcls.H(i) for i in range(self.stcls.length)]
@property
def deps(self):
return [self.doc.vocab.strings[self.stcls._sent[i].dep]
for i in range(self.stcls.length)]
def predict(self):
self.parser.predict(self.stcls, &self.eg.c)
action = self.parser.moves.c[self.eg.c.guess]
return self.parser.moves.move_name(action.move, action.label)
def transition(self, action_name):
if action_name == '_':
action_name = self.predict()
action = self.parser.moves.lookup_transition(action_name)
action.do(self.stcls, action.label)
def finish(self):
if self.stcls.is_final():
self.parser.moves.finalize_state(self.stcls)
self.doc.set_parse(self.stcls._sent)