mirror of https://github.com/explosion/spaCy.git
* Refactor _advance_beam function
This commit is contained in:
parent
0786d9b3c7
commit
d1b55310a1
|
@ -1,9 +1,11 @@
|
||||||
|
# cython: profile=True
|
||||||
"""
|
"""
|
||||||
MALT-style dependency parser
|
MALT-style dependency parser
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
cimport cython
|
cimport cython
|
||||||
from libc.stdint cimport uint32_t, uint64_t
|
from libc.stdint cimport uint32_t, uint64_t
|
||||||
|
from libc.string cimport memset, memcpy
|
||||||
import random
|
import random
|
||||||
import os.path
|
import os.path
|
||||||
from os import path
|
from os import path
|
||||||
|
@ -152,11 +154,11 @@ cdef class Parser:
|
||||||
self._advance_beam(gold, gold_parse, True)
|
self._advance_beam(gold, gold_parse, True)
|
||||||
violn.check(pred, gold)
|
violn.check(pred, gold)
|
||||||
counts = {}
|
counts = {}
|
||||||
if pred._states[0].loss >= 1:
|
if pred.loss >= 1:
|
||||||
self._count_feats(counts, tokens, violn.g_hist, 1)
|
self._count_feats(counts, tokens, violn.g_hist, 1)
|
||||||
self._count_feats(counts, tokens, violn.p_hist, -1)
|
self._count_feats(counts, tokens, violn.p_hist, -1)
|
||||||
self.model._model.update(counts)
|
self.model._model.update(counts)
|
||||||
return pred._states[0].loss
|
return pred.loss
|
||||||
|
|
||||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
|
@ -167,22 +169,26 @@ cdef class Parser:
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
state = <State*>beam.at(i)
|
state = <State*>beam.at(i)
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
scores = self.model.score(context)
|
self.model.set_scores(beam.scores[i], context)
|
||||||
validities = self.moves.get_valid(state)
|
self.moves.set_valid(beam.is_valid[i], state)
|
||||||
if gold is None:
|
|
||||||
for j in range(self.moves.n_moves):
|
if follow_gold:
|
||||||
beam.set_cell(i, j, scores[j], validities[j], 0)
|
for i in range(beam.size):
|
||||||
elif not follow_gold:
|
state = <State*>beam.at(i)
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
move = &self.moves.c[j]
|
move = &self.moves.c[j]
|
||||||
cost = move.get_cost(move, state, gold)
|
beam.costs[i][j] = move.get_cost(move, state, gold)
|
||||||
beam.set_cell(i, j, scores[j], validities[j], cost)
|
beam.is_valid[i][j] = beam.costs[i][j] == 0
|
||||||
else:
|
elif gold is not None:
|
||||||
|
for i in range(beam.size):
|
||||||
|
state = <State*>beam.at(i)
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
move = &self.moves.c[j]
|
move = &self.moves.c[j]
|
||||||
cost = move.get_cost(move, state, gold)
|
beam.costs[i][j] = move.get_cost(move, state, gold)
|
||||||
beam.set_cell(i, j, scores[j], cost == 0, cost)
|
|
||||||
beam.advance(_transition_state, <void*>self.moves.c)
|
beam.advance(_transition_state, <void*>self.moves.c)
|
||||||
|
state = <State*>beam.at(0)
|
||||||
|
if state.sent[state.i].sent_end:
|
||||||
|
beam.size = int(beam.size / 2)
|
||||||
beam.check_done(_check_final_state, NULL)
|
beam.check_done(_check_final_state, NULL)
|
||||||
|
|
||||||
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
||||||
|
|
Loading…
Reference in New Issue