Break parser batches into sub-batches, sorted by length.

This commit is contained in:
Matthew Honnibal 2017-10-18 21:45:01 +02:00
parent 394633efce
commit 633a75c7e0
1 changed file with 45 additions and 46 deletions

View File

@ -9,6 +9,7 @@ from collections import Counter, OrderedDict
import ujson import ujson
import json import json
import contextlib import contextlib
import numpy
from libc.math cimport exp from libc.math cimport exp
cimport cython cimport cython
@ -27,7 +28,7 @@ from libc.string cimport memset, memcpy
from libc.stdlib cimport malloc, calloc, free from libc.stdlib cimport malloc, calloc, free
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from thinc.linear.avgtron cimport AveragedPerceptron from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.linalg cimport VecVec from thinc.linalg cimport Vec, VecVec
from thinc.structs cimport SparseArrayC, FeatureC, ExampleC from thinc.structs cimport SparseArrayC, FeatureC, ExampleC
from thinc.extra.eg cimport Example from thinc.extra.eg cimport Example
from thinc.extra.search cimport Beam from thinc.extra.search cimport Beam
@ -288,6 +289,8 @@ cdef class Parser:
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
) )
upper.is_noop = False upper.is_noop = False
print(upper._layers)
print(upper._layers[0]._layers)
# TODO: This is an unfortunate hack atm! # TODO: This is an unfortunate hack atm!
# Used to set input dimensions in network. # Used to set input dimensions in network.
@ -391,19 +394,22 @@ cdef class Parser:
beam_density = self.cfg.get('beam_density', 0.0) beam_density = self.cfg.get('beam_density', 0.0)
cdef Doc doc cdef Doc doc
cdef Beam beam cdef Beam beam
for docs in cytoolz.partition_all(batch_size, docs): for batch in cytoolz.partition_all(batch_size, docs):
docs = list(docs) batch = list(batch)
by_length = sorted(list(batch), key=lambda doc: len(doc))
for subbatch in cytoolz.partition_all(32, by_length):
subbatch = list(subbatch)
if beam_width == 1: if beam_width == 1:
parse_states = self.parse_batch(docs) parse_states = self.parse_batch(subbatch)
beams = [] beams = []
else: else:
beams = self.beam_parse(docs, beams = self.beam_parse(subbatch,
beam_width=beam_width, beam_density=beam_density) beam_width=beam_width, beam_density=beam_density)
parse_states = [] parse_states = []
for beam in beams: for beam in beams:
parse_states.append(<StateClass>beam.at(0)) parse_states.append(<StateClass>beam.at(0))
self.set_annotations(docs, parse_states) self.set_annotations(subbatch, parse_states)
yield from docs yield from batch
def parse_batch(self, docs): def parse_batch(self, docs):
cdef: cdef:
@ -437,38 +443,22 @@ cdef class Parser:
cdef np.ndarray token_ids = numpy.zeros((nr_state, nr_feat), dtype='i') cdef np.ndarray token_ids = numpy.zeros((nr_state, nr_feat), dtype='i')
cdef np.ndarray is_valid = numpy.zeros((nr_state, nr_class), dtype='i') cdef np.ndarray is_valid = numpy.zeros((nr_state, nr_class), dtype='i')
cdef np.ndarray scores cdef np.ndarray scores
cdef np.ndarray hidden_weights = numpy.ascontiguousarray(vec2scores._layers[-1].W.T)
cdef np.ndarray hidden_bias = vec2scores._layers[-1].b
hW = <float*>hidden_weights.data
hb = <float*>hidden_bias.data
cdef int nr_hidden = hidden_weights.shape[0]
c_token_ids = <int*>token_ids.data c_token_ids = <int*>token_ids.data
c_is_valid = <int*>is_valid.data c_is_valid = <int*>is_valid.data
cdef int has_hidden = not getattr(vec2scores, 'is_noop', False) cdef int has_hidden = not getattr(vec2scores, 'is_noop', False)
cdef int nr_step cdef int nr_step
while not next_step.empty(): while not next_step.empty():
nr_step = next_step.size() nr_step = next_step.size()
if not has_hidden: for i in cython.parallel.prange(nr_step, num_threads=3,
for i in cython.parallel.prange(nr_step, num_threads=6,
nogil=True): nogil=True):
self._parse_step(next_step[i], self._parse_step(next_step[i],
feat_weights, nr_class, nr_feat, nr_piece) feat_weights, hW, hb, nr_class, nr_hidden, nr_feat, nr_piece)
else:
hists = []
for i in range(nr_step):
st = next_step[i]
st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
self.moves.set_valid(&c_is_valid[i*nr_class], st)
hists.append([st.get_hist(j+1) for j in range(8)])
hists = numpy.asarray(hists)
vectors = state2vec(token_ids[:next_step.size()])
if self.cfg.get('hist_size'):
scores = vec2scores((vectors, hists))
else:
scores = vec2scores(vectors)
c_scores = <float*>scores.data
for i in range(nr_step):
st = next_step[i]
guess = arg_max_if_valid(
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
action = self.moves.c[guess]
action.do(st, action.label)
st.push_hist(guess)
this_step, next_step = next_step, this_step this_step, next_step = next_step, this_step
next_step.clear() next_step.clear()
for st in this_step: for st in this_step:
@ -528,24 +518,33 @@ cdef class Parser:
return beams return beams
cdef void _parse_step(self, StateC* state, cdef void _parse_step(self, StateC* state,
const float* feat_weights, const float* feat_weights, const float* hW, const float* hb,
int nr_class, int nr_feat, int nr_piece) nogil: int nr_class, int nr_hidden, int nr_feat, int nr_piece) nogil:
'''This only works with no hidden layers -- fast but inaccurate''' '''This only works with no hidden layers -- fast but inaccurate'''
token_ids = <int*>calloc(nr_feat, sizeof(int)) token_ids = <int*>calloc(nr_feat, sizeof(int))
scores = <float*>calloc(nr_class * nr_piece, sizeof(float)) vector = <float*>calloc(nr_hidden * nr_piece, sizeof(float))
scores = <float*>calloc(nr_class, sizeof(float))
is_valid = <int*>calloc(nr_class, sizeof(int)) is_valid = <int*>calloc(nr_class, sizeof(int))
state.set_context_tokens(token_ids, nr_feat) state.set_context_tokens(token_ids, nr_feat)
sum_state_features(scores, sum_state_features(vector,
feat_weights, token_ids, 1, nr_feat, nr_class * nr_piece) feat_weights, token_ids, 1, nr_feat, nr_hidden * nr_piece)
for i in range(nr_hidden):
feature = Vec.max(&vector[i*nr_piece], nr_piece)
for j in range(nr_class):
scores[j] += feature * hW[j]
hW += nr_class
for i in range(nr_class):
scores[i] += hb[i]
self.moves.set_valid(is_valid, state) self.moves.set_valid(is_valid, state)
guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece) guess = arg_max_if_valid(scores, is_valid, nr_class)
action = self.moves.c[guess] action = self.moves.c[guess]
action.do(state, action.label) action.do(state, action.label)
state.push_hist(guess) state.push_hist(guess)
free(is_valid) free(is_valid)
free(scores) free(scores)
free(vector)
free(token_ids) free(token_ids)
def update(self, docs, golds, drop=0., sgd=None, losses=None): def update(self, docs, golds, drop=0., sgd=None, losses=None):