* Refactor _advance_beam function

This commit is contained in:
Matthew Honnibal 2015-06-02 18:38:41 +02:00
parent 0786d9b3c7
commit d1b55310a1
1 changed file with 19 additions and 13 deletions

View File

@@ -1,9 +1,11 @@
# cython: profile=True
""" """
MALT-style dependency parser MALT-style dependency parser
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
cimport cython cimport cython
from libc.stdint cimport uint32_t, uint64_t from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy
import random import random
import os.path import os.path
from os import path from os import path
@@ -152,11 +154,11 @@ cdef class Parser:
self._advance_beam(gold, gold_parse, True) self._advance_beam(gold, gold_parse, True)
violn.check(pred, gold) violn.check(pred, gold)
counts = {} counts = {}
if pred._states[0].loss >= 1: if pred.loss >= 1:
self._count_feats(counts, tokens, violn.g_hist, 1) self._count_feats(counts, tokens, violn.g_hist, 1)
self._count_feats(counts, tokens, violn.p_hist, -1) self._count_feats(counts, tokens, violn.p_hist, -1)
self.model._model.update(counts) self.model._model.update(counts)
return pred._states[0].loss return pred.loss
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
@@ -167,22 +169,26 @@ cdef class Parser:
for i in range(beam.size): for i in range(beam.size):
state = <State*>beam.at(i) state = <State*>beam.at(i)
fill_context(context, state) fill_context(context, state)
scores = self.model.score(context) self.model.set_scores(beam.scores[i], context)
validities = self.moves.get_valid(state) self.moves.set_valid(beam.is_valid[i], state)
if gold is None:
for j in range(self.moves.n_moves): if follow_gold:
beam.set_cell(i, j, scores[j], validities[j], 0) for i in range(beam.size):
elif not follow_gold: state = <State*>beam.at(i)
for j in range(self.moves.n_moves): for j in range(self.moves.n_moves):
move = &self.moves.c[j] move = &self.moves.c[j]
cost = move.get_cost(move, state, gold) beam.costs[i][j] = move.get_cost(move, state, gold)
beam.set_cell(i, j, scores[j], validities[j], cost) beam.is_valid[i][j] = beam.costs[i][j] == 0
else: elif gold is not None:
for i in range(beam.size):
state = <State*>beam.at(i)
for j in range(self.moves.n_moves): for j in range(self.moves.n_moves):
move = &self.moves.c[j] move = &self.moves.c[j]
cost = move.get_cost(move, state, gold) beam.costs[i][j] = move.get_cost(move, state, gold)
beam.set_cell(i, j, scores[j], cost == 0, cost)
beam.advance(_transition_state, <void*>self.moves.c) beam.advance(_transition_state, <void*>self.moves.c)
state = <State*>beam.at(0)
if state.sent[state.i].sent_end:
beam.size = int(beam.size / 2)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc): def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):