diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 967e64cc9..ffe38865c 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -1,9 +1,11 @@ +# cython: profile=True """ MALT-style dependency parser """ from __future__ import unicode_literals cimport cython from libc.stdint cimport uint32_t, uint64_t +from libc.string cimport memset, memcpy import random import os.path from os import path @@ -152,11 +154,11 @@ cdef class Parser: self._advance_beam(gold, gold_parse, True) violn.check(pred, gold) counts = {} - if pred._states[0].loss >= 1: + if pred.loss >= 1: self._count_feats(counts, tokens, violn.g_hist, 1) self._count_feats(counts, tokens, violn.p_hist, -1) self.model._model.update(counts) - return pred._states[0].loss + return pred.loss def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): cdef atom_t[CONTEXT_SIZE] context @@ -167,22 +169,26 @@ cdef class Parser: for i in range(beam.size): state = beam.at(i) fill_context(context, state) - scores = self.model.score(context) - validities = self.moves.get_valid(state) - if gold is None: - for j in range(self.moves.n_moves): - beam.set_cell(i, j, scores[j], validities[j], 0) - elif not follow_gold: + self.model.set_scores(beam.scores[i], context) + self.moves.set_valid(beam.is_valid[i], state) + + if follow_gold: + for i in range(beam.size): + state = beam.at(i) for j in range(self.moves.n_moves): move = &self.moves.c[j] - cost = move.get_cost(move, state, gold) - beam.set_cell(i, j, scores[j], validities[j], cost) - else: + beam.costs[i][j] = move.get_cost(move, state, gold) + beam.is_valid[i][j] = beam.costs[i][j] == 0 + elif gold is not None: + for i in range(beam.size): + state = beam.at(i) for j in range(self.moves.n_moves): move = &self.moves.c[j] - cost = move.get_cost(move, state, gold) - beam.set_cell(i, j, scores[j], cost == 0, cost) + beam.costs[i][j] = move.get_cost(move, state, gold) beam.advance(_transition_state, self.moves.c) + state = beam.at(0) + if state.sent[state.i].sent_end: + beam.size = int(beam.size / 2) beam.check_done(_check_final_state, NULL) def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):