mirror of https://github.com/explosion/spaCy.git
Fix history features
This commit is contained in:
parent
b770f4e108
commit
278a4c17c6
|
@ -70,6 +70,8 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
|
||||||
from . import _beam_utils
|
from . import _beam_utils
|
||||||
|
|
||||||
USE_HISTORY = True
|
USE_HISTORY = True
|
||||||
|
HIST_SIZE = 2
|
||||||
|
HIST_DIMS = 16
|
||||||
|
|
||||||
def get_templates(*args, **kwargs):
|
def get_templates(*args, **kwargs):
|
||||||
return []
|
return []
|
||||||
|
@ -262,13 +264,11 @@ cdef class Parser:
|
||||||
|
|
||||||
with Model.use_device('cpu'):
|
with Model.use_device('cpu'):
|
||||||
if depth == 0:
|
if depth == 0:
|
||||||
hist_size = 8
|
|
||||||
nr_dim = 8
|
|
||||||
if USE_HISTORY:
|
if USE_HISTORY:
|
||||||
upper = chain(
|
upper = chain(
|
||||||
HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
|
HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE,
|
||||||
nr_dim=nr_dim),
|
nr_dim=HIST_DIMS),
|
||||||
zero_init(Affine(nr_class, nr_class+hist_size*nr_dim,
|
zero_init(Affine(nr_class, nr_class+HIST_SIZE*HIST_DIMS,
|
||||||
drop_factor=0.0)))
|
drop_factor=0.0)))
|
||||||
upper.is_noop = False
|
upper.is_noop = False
|
||||||
else:
|
else:
|
||||||
|
@ -736,15 +736,13 @@ cdef class Parser:
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef int[500] is_valid # TODO: Unhack
|
cdef int[500] is_valid # TODO: Unhack
|
||||||
cdef float* c_scores = &scores[0, 0]
|
cdef float* c_scores = &scores[0, 0]
|
||||||
hists = []
|
|
||||||
for state in states:
|
for state in states:
|
||||||
self.moves.set_valid(is_valid, state.c)
|
self.moves.set_valid(is_valid, state.c)
|
||||||
guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
|
guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
|
||||||
action = self.moves.c[guess]
|
action = self.moves.c[guess]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
c_scores += scores.shape[1]
|
c_scores += scores.shape[1]
|
||||||
hists.append(guess)
|
state.c.push_hist(guess)
|
||||||
return hists
|
|
||||||
|
|
||||||
def get_batch_loss(self, states, golds, float[:, ::1] scores):
|
def get_batch_loss(self, states, golds, float[:, ::1] scores):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
|
Loading…
Reference in New Issue