From bdf2dba9fb411b2d16633b3952f48ab0c987a59b Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 7 May 2017 02:02:43 +0200 Subject: [PATCH 1/9] WIP on refactor, with hidde pre-computing --- bin/parser/train_ud.py | 24 +++- spacy/_ml.py | 236 ++++++++++----------------------- spacy/syntax/parser.pyx | 253 +++++++++++++++++------------------- spacy/syntax/stateclass.pyx | 19 +-- 4 files changed, 216 insertions(+), 316 deletions(-) diff --git a/bin/parser/train_ud.py b/bin/parser/train_ud.py index f0aff22b2..793f1c821 100644 --- a/bin/parser/train_ud.py +++ b/bin/parser/train_ud.py @@ -82,6 +82,7 @@ def organize_data(vocab, train_sents): def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): LangClass = spacy.util.get_lang_class(lang_name) train_sents = list(read_conllx(train_loc)) + dev_sents = list(read_conllx(dev_loc)) train_sents = PseudoProjectivity.preprocess_training_data(train_sents) actions = ArcEager.get_actions(gold_parses=train_sents) @@ -136,6 +137,7 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0) Xs, ys = organize_data(vocab, train_sents) + dev_Xs, dev_ys = organize_data(vocab, dev_sents) Xs = Xs[:100] ys = ys[:100] with encoder.model.begin_training(Xs[:100], ys[:100]) as (trainer, optimizer): @@ -145,13 +147,13 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): parser.begin_training(docs, ys) nn_loss = [0.] def track_progress(): - scorer = score_model(vocab, encoder, tagger, parser, Xs, ys) + scorer = score_model(vocab, encoder, tagger, parser, dev_Xs, dev_ys) itn = len(nn_loss) print('%d:\t%.3f\t%.3f\t%.3f' % (itn, nn_loss[-1], scorer.uas, scorer.tags_acc)) nn_loss.append(0.) trainer.each_epoch.append(track_progress) - trainer.batch_size = 6 - trainer.nb_epoch = 10000 + trainer.batch_size = 12 + trainer.nb_epoch = 2 for docs, golds in trainer.iterate(Xs, ys, progress_bar=False): docs = [Doc(vocab, words=[w.text for w in doc]) for doc in docs] tokvecs, upd_tokvecs = encoder.begin_update(docs) @@ -163,10 +165,20 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): upd_tokvecs(d_tokvecs, sgd=optimizer) nn_loss[-1] += loss nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser) - nlp.end_training(model_dir) - scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) - print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc)) + #nlp.end_training(model_dir) + #scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) + #print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc)) if __name__ == '__main__': + import cProfile + import pstats + if 0: + plac.call(main) + else: + cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof") + s = pstats.Stats("Profile.prof") + s.strip_dirs().sort_stats("time").print_stats() + + plac.call(main) diff --git a/spacy/_ml.py b/spacy/_ml.py index 87549369f..be82c6b2c 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -21,181 +21,23 @@ def get_col(idx): return layerize(forward) -def build_model(state2vec, width, depth, nr_class): - with Model.define_operators({'>>': chain, '**': clone}): - model = ( - state2vec - >> Maxout(width, 1344) - >> Maxout(width, width) - >> Affine(nr_class, width) - ) - return model - - -def build_debug_model(state2vec, width, depth, nr_class): - with Model.define_operators({'>>': chain, '**': clone}): - model = ( - state2vec - >> Maxout(width) - >> Affine(nr_class) - ) - return model - - -def build_debug_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2): - ops = Model.ops - def forward(tokens_attrs_vectors, drop=0.): - tokens, attr_vals, tokvecs = tokens_attrs_vectors - - orig_tokvecs_shape = tokvecs.shape - tokvecs = tokvecs.reshape((tokvecs.shape[0], tokvecs.shape[1] * - tokvecs.shape[2])) - - vector = tokvecs - - def backward(d_vector, sgd=None): - d_tokvecs = vector.reshape(orig_tokvecs_shape) - return (tokens, d_tokvecs) - return vector, backward - model = layerize(forward) - return model - - -def build_state2vec(nr_context_tokens, width, nr_vector=1000): - ops = Model.ops - with Model.define_operators({'|': concatenate, '+': add, '>>': chain}): - - hiddens = [get_col(i) >> Affine(width) for i in range(nr_context_tokens)] - model = ( - get_token_vectors - >> add(*hiddens) - >> Maxout(width) - ) - return model - - -def print_shape(prefix): - def forward(X, drop=0.): - return X, lambda dX, **kwargs: dX - return layerize(forward) - - -@layerize -def get_token_vectors(tokens_attrs_vectors, drop=0.): - ops = Model.ops - tokens, attrs, vectors = tokens_attrs_vectors - def backward(d_output, sgd=None): - return (tokens, d_output) - return vectors, backward - - -def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2): - embed_tags = _reshape(chain(get_col(0), HashEmbed(16, nr_vector))) - embed_deps = _reshape(chain(get_col(1), HashEmbed(16, nr_vector))) - ops = embed_tags.ops - def forward(tokens_attrs_vectors, drop=0.): - tokens, attr_vals, tokvecs = tokens_attrs_vectors - tagvecs, bp_tagvecs = embed_deps.begin_update(attr_vals, drop=drop) - depvecs, bp_depvecs = embed_tags.begin_update(attr_vals, drop=drop) - orig_tokvecs_shape = tokvecs.shape - tokvecs = tokvecs.reshape((tokvecs.shape[0], tokvecs.shape[1] * - tokvecs.shape[2])) - - shapes = (tagvecs.shape, depvecs.shape, tokvecs.shape) - assert tagvecs.shape[0] == depvecs.shape[0] == tokvecs.shape[0], shapes - vector = ops.xp.hstack((tagvecs, depvecs, tokvecs)) - - def backward(d_vector, sgd=None): - d_tagvecs, d_depvecs, d_tokvecs = backprop_concatenate(d_vector, shapes) - assert d_tagvecs.shape == shapes[0], (d_tagvecs.shape, shapes) - assert d_depvecs.shape == shapes[1], (d_depvecs.shape, shapes) - assert d_tokvecs.shape == shapes[2], (d_tokvecs.shape, shapes) - bp_tagvecs(d_tagvecs) - bp_depvecs(d_depvecs) - d_tokvecs = d_tokvecs.reshape(orig_tokvecs_shape) - - return (tokens, d_tokvecs) - return vector, backward - model = layerize(forward) - model._layers = [embed_tags, embed_deps] - return model - - -def backprop_concatenate(gradient, shapes): - grads = [] - start = 0 - for shape in shapes: - end = start + shape[1] - grads.append(gradient[:, start : end]) - start = end - return grads - - -def _reshape(layer): - '''Transforms input with shape - (states, tokens, features) - into input with shape: - (states * tokens, features) - So that it can be used with a token-wise feature extraction layer, e.g. - an embedding layer. The embedding layer outputs: - (states * tokens, ndim) - But we want to concatenate the vectors for the tokens, so we produce: - (states, tokens * ndim) - We then need to reverse the transforms to do the backward pass. Recall - the simple rule here: each layer is a map: - inputs -> (outputs, (d_outputs->d_inputs)) - So the shapes must match like this: - shape of forward input == shape of backward output - shape of backward input == shape of forward output - ''' - def forward(X__bfm, drop=0.): - b, f, m = X__bfm.shape - B = b*f - M = f*m - X__Bm = X__bfm.reshape((B, m)) - y__Bn, bp_yBn = layer.begin_update(X__Bm, drop=drop) - n = y__Bn.shape[1] - N = f * n - y__bN = y__Bn.reshape((b, N)) - def backward(dy__bN, sgd=None): - dy__Bn = dy__bN.reshape((B, n)) - dX__Bm = bp_yBn(dy__Bn, sgd) - if dX__Bm is None: - return None - else: - return dX__Bm.reshape((b, f, m)) - return y__bN, backward - model = layerize(forward) - model._layers.append(layer) - return model - - -@layerize -def flatten(seqs, drop=0.): - ops = Model.ops - def finish_update(d_X, sgd=None): - return d_X - X = ops.xp.concatenate([ops.asarray(seq) for seq in seqs]) - return X, finish_update - - def build_tok2vec(lang, width, depth=2, embed_size=1000): cols = [ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG] with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}): #static = get_col(cols.index(ID)) >> StaticVectors(lang, width) lower = get_col(cols.index(LOWER)) >> HashEmbed(width, embed_size) - prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size) - suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size) - shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size) - tag = get_col(cols.index(TAG)) >> HashEmbed(width, embed_size) + prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width//4, embed_size) + suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width//4, embed_size) + shape = get_col(cols.index(SHAPE)) >> HashEmbed(width//4, embed_size) + tag = get_col(cols.index(TAG)) >> HashEmbed(width//2, embed_size) tok2vec = ( doc2feats(cols) >> with_flatten( #(static | prefix | suffix | shape) (lower | prefix | suffix | shape | tag) - >> Maxout(width, width*5) - #>> (ExtractWindow(nW=1) >> Maxout(width, width*3)) - #>> (ExtractWindow(nW=1) >> Maxout(width, width*3)) + >> Maxout(width) + >> (ExtractWindow(nW=1) >> Maxout(width, width*3)) + >> (ExtractWindow(nW=1) >> Maxout(width, width*3)) ) ) return tok2vec @@ -208,3 +50,67 @@ def doc2feats(cols): return feats, None model = layerize(forward) return model + + +def build_feature_precomputer(model, feat_maps): + '''Allow a model to be "primed" by pre-computing input features in bulk. + + This is used for the parser, where we want to take a batch of documents, + and compute vectors for each (token, position) pair. These vectors can then + be reused, especially for beam-search. + + Let's say we're using 12 features for each state, e.g. word at start of + buffer, three words on stack, their children, etc. In the normal arc-eager + system, a document of length N is processed in 2*N states. This means we'll + create 2*N*12 feature vectors --- but if we pre-compute, we only need + N*12 vector computations. The saving for beam-search is much better: + if we have a beam of k, we'll normally make 2*N*12*K computations -- + so we can save the factor k. This also gives a nice CPU/GPU division: + we can do all our hard maths up front, packed into large multiplications, + and do the hard-to-program parsing on the CPU. + ''' + def precompute(input_vectors): + cached, backprops = zip(*[lyr.begin_update(input_vectors) + for lyr in feat_maps) + def forward(batch_token_ids, drop=0.): + output = ops.allocate((batch_size, output_width)) + # i: batch index + # j: position index (i.e. N0, S0, etc + # tok_i: Index of the token within its document + for i, token_ids in enumerate(batch_token_ids): + for j, tok_i in enumerate(token_ids): + output[i] += cached[j][tok_i] + def backward(d_vector, sgd=None): + d_inputs = ops.allocate((batch_size, n_feat, vec_width)) + for i, token_ids in enumerate(batch_token_ids): + for j in range(len(token_ids)): + d_inputs[i][j] = backprops[j](d_vector, sgd) + # Return the IDs, so caller can associate to correct token + return (batch_token_ids, d_inputs) + return vector, backward + return chain(layerize(forward), model) + return precompute + + +def print_shape(prefix): + def forward(X, drop=0.): + return X, lambda dX, **kwargs: dX + return layerize(forward) + + +@layerize +def get_token_vectors(tokens_attrs_vectors, drop=0.): + ops = Model.ops + tokens, attrs, vectors = tokens_attrs_vectors + def backward(d_output, sgd=None): + return (tokens, d_output) + return vectors, backward + + +@layerize +def flatten(seqs, drop=0.): + ops = Model.ops + def finish_update(d_X, sgd=None): + return d_X + X = ops.xp.concatenate([ops.asarray(seq) for seq in seqs]) + return X, finish_update diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index b7d33e1c9..1eb03fdb4 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -44,9 +44,7 @@ from ..strings cimport StringStore from ..gold cimport GoldParse from ..attrs cimport TAG, DEP -from .._ml import build_parser_state2vec, build_model -from .._ml import build_state2vec, build_model -from .._ml import build_debug_state2vec, build_debug_model +from .._ml import build_state2vec, build_model, precompute_hiddens USE_FTRL = True @@ -114,12 +112,12 @@ cdef class Parser: def __reduce__(self): return (Parser, (self.vocab, self.moves, self.model), None, None) - def build_model(self, width=64, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): + def build_model(self, width=32, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR) - state2vec = build_state2vec(nr_context_tokens, width, nr_vector) - #state2vec = build_debug_state2vec(width, nr_vector) - model = build_debug_model(state2vec, width*2, 2, self.moves.n_moves) - return model + + return build_model_precomputer( + build_model(state2vec, width*2, 2, self.moves.n_moves) + build_feature_maps(nr_context_tokens, width, nr_vector)) def __call__(self, Doc tokens): """ @@ -132,7 +130,7 @@ cdef class Parser: """ self.parse_batch([tokens]) self.moves.finalize_doc(tokens) - + def pipe(self, stream, int batch_size=1000, int n_threads=2): """ Process a stream of documents. @@ -167,158 +165,50 @@ cdef class Parser: yield doc def parse_batch(self, docs): - states = self._init_states(docs) - nr_class = self.moves.n_moves cdef Doc doc cdef StateClass state - cdef int guess - tokvecs = [d.tensor for d in docs] - all_states = list(states) - todo = zip(states, tokvecs) + model, states = self.init_batch(docs) + todo = list(states) while todo: - states, tokvecs = zip(*todo) - scores, _ = self._begin_update(states, tokvecs) - for state, guess in zip(states, scores.argmax(axis=1)): - action = self.moves.c[guess] - action.do(state.c, action.label) - todo = filter(lambda sp: not sp[0].py_is_final(), todo) - for state, doc in zip(all_states, docs): + todo = model(todo) + for state, doc in zip(states, docs): self.moves.finalize_state(state.c) for i in range(doc.length): doc.c[i] = state.c._sent[i] - def begin_training(self, docs, golds): - for gold in golds: - self.moves.preprocess_gold(gold) - states = self._init_states(docs) - tokvecs = [d.tensor for d in docs] - d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs] - nr_class = self.moves.n_moves - costs = self.model.ops.allocate((len(docs), nr_class), dtype='f') - gradients = self.model.ops.allocate((len(docs), nr_class), dtype='f') - is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i') - attr_names = self.model.ops.allocate((2,), dtype='i') - attr_names[0] = TAG - attr_names[1] = DEP - - features = self._get_features(states, tokvecs, attr_names) - self.model.begin_training(features) - - def update(self, docs, golds, drop=0., sgd=None): if isinstance(docs, Doc) and isinstance(golds, GoldParse): return self.update([docs], [golds], drop=drop) for gold in golds: self.moves.preprocess_gold(gold) - states = self._init_states(docs) - tokvecs = [d.tensor for d in docs] + + model, states = self.init_batch(docs) + d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs] - nr_class = self.moves.n_moves output = list(d_tokens) - todo = zip(states, tokvecs, golds, d_tokens) - assert len(states) == len(todo) - losses = [] + todo = zip(states, golds, d_tokens) while todo: - states, tokvecs, golds, d_tokens = zip(*todo) - scores, finish_update = self._begin_update(states, tokvecs) - token_ids, batch_token_grads = finish_update(golds, sgd=sgd, losses=losses, - force_gold=False) + states, golds, d_tokens = zip(*todo) + states, finish_update = model.begin_update(states) + d_state_features = finish_update(golds, sgd=sgd) for i, tok_ids in enumerate(token_ids): for j, tok_i in enumerate(tok_ids): if tok_i >= 0: - d_tokens[i][tok_i] += batch_token_grads[i, j] - - self._transition_batch(states, scores) + d_tokens[i][tok_i] += d_state_features[i, j] # Get unfinished states (and their matching gold and token gradients) todo = filter(lambda sp: not sp[0].py_is_final(), todo) return output, sum(losses) - def _begin_update(self, states, tokvecs, drop=0.): - nr_class = self.moves.n_moves - attr_names = self.model.ops.allocate((2,), dtype='i') - attr_names[0] = TAG - attr_names[1] = DEP + def begin_training(self, docs, golds): + for gold in golds: + self.moves.preprocess_gold(gold) + states = self._init_states(docs) + tokvecs = [d.tensor for d in docs] - features = self._get_features(states, tokvecs, attr_names) - scores, finish_update = self.model.begin_update(features, drop=drop) - assert scores.shape[0] == len(states), (len(states), scores.shape) - assert len(scores.shape) == 2 - is_valid = self.model.ops.allocate((len(states), nr_class), dtype='i') - self._validate_batch(is_valid, states) - softmaxed = self.model.ops.softmax(scores) - softmaxed *= is_valid - softmaxed /= softmaxed.sum(axis=1).reshape((softmaxed.shape[0], 1)) - def backward(golds, sgd=None, losses=[], force_gold=False): - nonlocal softmaxed - costs = self.model.ops.allocate((len(states), nr_class), dtype='f') - d_scores = self.model.ops.allocate((len(states), nr_class), dtype='f') + features = self._get_features(states, tokvecs) + self.model.begin_training(features) - self._cost_batch(costs, is_valid, states, golds) - self._set_gradient(d_scores, scores, is_valid, costs) - losses.append(numpy.abs(d_scores).sum()) - if force_gold: - softmaxed *= costs <= 0 - return finish_update(d_scores, sgd=sgd) - return softmaxed, backward - - def _init_states(self, docs): - states = [] - cdef Doc doc - cdef StateClass state - for i, doc in enumerate(docs): - state = StateClass.init(doc.c, doc.length) - self.moves.initialize_state(state.c) - states.append(state) - return states - - def _get_features(self, states, all_tokvecs, attr_names, - nF=1, nB=0, nS=2, nL=2, nR=2): - n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR) - vector_length = all_tokvecs[0].shape[1] - tokens = self.model.ops.allocate((len(states), n_tokens), dtype='int32') - features = self.model.ops.allocate((len(states), n_tokens, attr_names.shape[0]), dtype='uint64') - tokvecs = self.model.ops.allocate((len(states), n_tokens, vector_length), dtype='f') - for i, state in enumerate(states): - state.set_context_tokens(tokens[i], nF, nB, nS, nL, nR) - state.set_attributes(features[i], tokens[i], attr_names) - state.set_token_vectors(tokvecs[i], all_tokvecs[i], tokens[i]) - return (tokens, features, tokvecs) - - def _validate_batch(self, int[:, ::1] is_valid, states): - cdef StateClass state - cdef int i - for i, state in enumerate(states): - self.moves.set_valid(&is_valid[i, 0], state.c) - - def _cost_batch(self, weight_t[:, ::1] costs, int[:, ::1] is_valid, - states, golds): - cdef int i - cdef StateClass state - cdef GoldParse gold - for i, (state, gold) in enumerate(zip(states, golds)): - self.moves.set_costs(&is_valid[i, 0], &costs[i, 0], state, gold) - - def _transition_batch(self, states, scores): - cdef StateClass state - cdef int guess - for state, guess in zip(states, scores.argmax(axis=1)): - action = self.moves.c[guess] - action.do(state.c, action.label) - - def _set_gradient(self, gradients, scores, is_valid, costs): - """Do multi-label log loss""" - cdef double Z, gZ, max_, g_max - n = gradients.shape[0] - scores = scores * is_valid - g_scores = scores * is_valid * (costs <= 0.) - exps = numpy.exp(scores - scores.max(axis=1).reshape((n, 1))) - exps *= is_valid - g_exps = numpy.exp(g_scores - g_scores.max(axis=1).reshape((n, 1))) - g_exps *= costs <= 0. - g_exps *= is_valid - gradients[:] = exps / exps.sum(axis=1).reshape((n, 1)) - gradients -= g_exps / g_exps.sum(axis=1).reshape((n, 1)) def step_through(self, Doc doc, GoldParse gold=None): """ @@ -355,6 +245,97 @@ cdef class Parser: self.cfg.setdefault('extra_labels', []).append(label) +def _transition_batch(self, states, scores): + cdef StateClass state + cdef int guess + for state, guess in zip(states, scores.argmax(axis=1)): + action = self.moves.c[guess] + action.do(state.c, action.label) + +def _set_gradient(self, gradients, scores, is_valid, costs): + """Do multi-label log loss""" + cdef double Z, gZ, max_, g_max + n = gradients.shape[0] + scores = scores * is_valid + g_scores = scores * is_valid * (costs <= 0.) + exps = numpy.exp(scores - scores.max(axis=1).reshape((n, 1))) + exps *= is_valid + g_exps = numpy.exp(g_scores - g_scores.max(axis=1).reshape((n, 1))) + g_exps *= costs <= 0. + g_exps *= is_valid + gradients[:] = exps / exps.sum(axis=1).reshape((n, 1)) + gradients -= g_exps / g_exps.sum(axis=1).reshape((n, 1)) + + +def _begin_update(self, model, states, tokvecs, drop=0.): + nr_class = self.moves.n_moves + attr_names = self.model.ops.allocate((2,), dtype='i') + attr_names[0] = TAG + attr_names[1] = DEP + + features = self._get_features(states, tokvecs, attr_names) + scores, finish_update = self.model.begin_update(features, drop=drop) + assert scores.shape[0] == len(states), (len(states), scores.shape) + assert len(scores.shape) == 2 + is_valid = self.model.ops.allocate((len(states), nr_class), dtype='i') + self._validate_batch(is_valid, states) + softmaxed = self.model.ops.softmax(scores) + softmaxed *= is_valid + softmaxed /= softmaxed.sum(axis=1).reshape((softmaxed.shape[0], 1)) + def backward(golds, sgd=None, losses=[], force_gold=False): + nonlocal softmaxed + costs = self.model.ops.allocate((len(states), nr_class), dtype='f') + d_scores = self.model.ops.allocate((len(states), nr_class), dtype='f') + + self._cost_batch(costs, is_valid, states, golds) + self._set_gradient(d_scores, scores, is_valid, costs) + losses.append(numpy.abs(d_scores).sum()) + if force_gold: + softmaxed *= costs <= 0 + return finish_update(d_scores, sgd=sgd) + return softmaxed, backward + +def _init_states(self, docs): + states = [] + cdef Doc doc + cdef StateClass state + for i, doc in enumerate(docs): + state = StateClass.init(doc.c, doc.length) + self.moves.initialize_state(state.c) + states.append(state) + return states + +def _validate_batch(self, int[:, ::1] is_valid, states): + cdef StateClass state + cdef int i + for i, state in enumerate(states): + self.moves.set_valid(&is_valid[i, 0], state.c) + +def _cost_batch(self, weight_t[:, ::1] costs, int[:, ::1] is_valid, + states, golds): + cdef int i + cdef StateClass state + cdef GoldParse gold + for i, (state, gold) in enumerate(zip(states, golds)): + self.moves.set_costs(&is_valid[i, 0], &costs[i, 0], state, gold) + + + +def _get_features(self, states, all_tokvecs, attr_names, + nF=1, nB=0, nS=2, nL=2, nR=2): + n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR) + vector_length = all_tokvecs[0].shape[1] + tokens = self.model.ops.allocate((len(states), n_tokens), dtype='int32') + features = self.model.ops.allocate((len(states), n_tokens, attr_names.shape[0]), dtype='uint64') + tokvecs = self.model.ops.allocate((len(states), n_tokens, vector_length), dtype='f') + for i, state in enumerate(states): + state.set_context_tokens(tokens[i], nF, nB, nS, nL, nR) + state.set_attributes(features[i], tokens[i], attr_names) + state.set_token_vectors(tokvecs[i], all_tokvecs[i], tokens[i]) + return (tokens, features, tokvecs) + + + cdef int dropout(FeatureC* feats, int nr_feat, float prob) except -1: if prob <= 0 or prob >= 1.: return 0 diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index 22d8134aa..d0a374b41 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -48,7 +48,7 @@ cdef class StateClass: @classmethod def nr_context_tokens(cls, int nF, int nB, int nS, int nL, int nR): - return 4 + return 5 def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2, nL=2, nR=2): @@ -56,14 +56,15 @@ cdef class StateClass: output[1] = self.B(1) output[2] = self.S(0) output[3] = self.S(1) - #output[4] = self.L(self.S(0), 1) - #output[5] = self.L(self.S(0), 2) - #output[6] = self.R(self.S(0), 1) - #output[7] = self.R(self.S(0), 2) - #output[7] = self.L(self.S(1), 1) - #output[8] = self.L(self.S(1), 2) - #output[9] = self.R(self.S(1), 1) - #output[10] = self.R(self.S(1), 2) + output[4] = self.S(2) + #output[5] = self.L(self.S(0), 1) + #output[6] = self.L(self.S(0), 2) + #output[7] = self.R(self.S(0), 1) + #output[8] = self.R(self.S(0), 2) + #output[10] = self.L(self.S(1), 1) + #output[11] = self.L(self.S(1), 2) + #output[12] = self.R(self.S(1), 1) + #output[13] = self.R(self.S(1), 2) def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names): cdef int i, j, tok_i From 4441866f55299ff0833b1d4b26359b24e4abf71d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 7 May 2017 22:47:06 +0200 Subject: [PATCH 2/9] Checkpoint -- nearly finished reimpl --- spacy/syntax/parser.pyx | 117 +++++++++++++++++++++++++++------------- 1 file changed, 79 insertions(+), 38 deletions(-) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 1eb03fdb4..c7170c747 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -54,8 +54,69 @@ def set_debug(val): DEBUG = val -def get_templates(*args, **kwargs): - return [] +def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, feat_maps, upper_model): + is_valid = model.ops.allocate((len(docs), system.n_moves), dtype='i') + costs = model.ops.allocate((len(docs), system.n_moves), dtype='f') + token_ids = model.ops.allocate((len(docs), StateClass.nr_context_tokens()), + dtype='uint64') + cached, backprops = zip(*[lyr.begin_update(tokvecs) for lyr in feat_maps) + + def forward(states, drop=0.): + nonlocal is_valid, costs, token_ids, features + is_valid = is_valid[:len(states)] + costs = costs[:len(states)] + token_ids = token_ids[:len(states)] + is_valid = is_valid[:len(states)] + for state in states: + state.set_context_tokens(&token_ids[i]) + moves.set_valid(&is_valid[i], state.c) + + features = cached[token_ids].sum(axis=1) + + scores, bp_scores = upper_model.begin_update(features, drop=drop) + softmaxed = model.ops.softmax(scores) + # Renormalize for invalid actions + softmaxed *= is_valid + softmaxed /= softmaxed.sum(axis=1).reshape((softmaxed.shape[0], 1)) + + def backward(golds, sgd=None): + nonlocal costs_, is_valid_, moves_ + cdef TransitionSystem moves = moves_ + cdef int[:, :] is_valid + cdef float[:, :] costs + for i, (state, gold) in enumerate(zip(states, golds)): + moves.set_costs(&costs[i], &is_valid[i], + state, gold) + set_log_loss(model.ops, d_scores, + scores, is_valid, costs) + d_tokens = bp_scores(d_scores, sgd) + return d_tokens + + return softmaxed, backward + + return layerize(forward) + + +def set_log_loss(ops, gradients, scores, is_valid, costs): + """Do multi-label log loss""" + n = gradients.shape[0] + scores = scores * is_valid + g_scores = scores * is_valid * (costs <= 0.) + exps = ops.xp.exp(scores - scores.max(axis=1).reshape((n, 1))) + exps *= is_valid + g_exps = ops.xp.exp(g_scores - g_scores.max(axis=1).reshape((n, 1))) + g_exps *= costs <= 0. + g_exps *= is_valid + gradients[:] = exps / exps.sum(axis=1).reshape((n, 1)) + gradients -= g_exps / g_exps.sum(axis=1).reshape((n, 1)) + + +def transition_batch(TransitionSystem moves, states, scores): + cdef StateClass state + cdef int guess + for state, guess in zip(states, scores.argmax(axis=1)): + action = moves.c[guess] + action.do(state.c, action.label) cdef class Parser: @@ -114,10 +175,8 @@ cdef class Parser: def build_model(self, width=32, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR) - - return build_model_precomputer( - build_model(state2vec, width*2, 2, self.moves.n_moves) - build_feature_maps(nr_context_tokens, width, nr_vector)) + self.model = build_model(width*2, 2, self.moves.n_moves) + self.feature_maps = build_feature_maps(nr_context_tokens, width, nr_vector)) def __call__(self, Doc tokens): """ @@ -129,7 +188,6 @@ cdef class Parser: None """ self.parse_batch([tokens]) - self.moves.finalize_doc(tokens) def pipe(self, stream, int batch_size=1000, int n_threads=2): """ @@ -167,14 +225,20 @@ cdef class Parser: def parse_batch(self, docs): cdef Doc doc cdef StateClass state - model, states = self.init_batch(docs) + model = get_greedy_model_for_batch([d.tensor for d in docs], + self.moves, self.model, self.feat_maps) + states = [StateClass.init(doc.c, doc.length) for doc in docs] todo = list(states) while todo: - todo = model(todo) + scores = model(todo) + transition_batch(self.moves, todo, scores) + todo = [st for st in states if not st.is_final()] for state, doc in zip(states, docs): self.moves.finalize_state(state.c) for i in range(doc.length): doc.c[i] = state.c._sent[i] + for doc in docs: + self.moves.finalize_parse(doc) def update(self, docs, golds, drop=0., sgd=None): if isinstance(docs, Doc) and isinstance(golds, GoldParse): @@ -182,20 +246,19 @@ cdef class Parser: for gold in golds: self.moves.preprocess_gold(gold) - model, states = self.init_batch(docs) + model = get_greedy_model_for_batch([d.tensor for d in docs], + self.moves, self.model, self.feat_maps) d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs] output = list(d_tokens) todo = zip(states, golds, d_tokens) while todo: states, golds, d_tokens = zip(*todo) - states, finish_update = model.begin_update(states) + scores, finish_update = model.begin_update(token_ids) d_state_features = finish_update(golds, sgd=sgd) - for i, tok_ids in enumerate(token_ids): - for j, tok_i in enumerate(tok_ids): - if tok_i >= 0: - d_tokens[i][tok_i] += d_state_features[i, j] - + for i, token_ids in enumerate(token_ids): + d_tokens[i][token_ids] += d_state_features[i] + transition_batch(self.moves, states) # Get unfinished states (and their matching gold and token gradients) todo = filter(lambda sp: not sp[0].py_is_final(), todo) return output, sum(losses) @@ -245,28 +308,6 @@ cdef class Parser: self.cfg.setdefault('extra_labels', []).append(label) -def _transition_batch(self, states, scores): - cdef StateClass state - cdef int guess - for state, guess in zip(states, scores.argmax(axis=1)): - action = self.moves.c[guess] - action.do(state.c, action.label) - -def _set_gradient(self, gradients, scores, is_valid, costs): - """Do multi-label log loss""" - cdef double Z, gZ, max_, g_max - n = gradients.shape[0] - scores = scores * is_valid - g_scores = scores * is_valid * (costs <= 0.) - exps = numpy.exp(scores - scores.max(axis=1).reshape((n, 1))) - exps *= is_valid - g_exps = numpy.exp(g_scores - g_scores.max(axis=1).reshape((n, 1))) - g_exps *= costs <= 0. - g_exps *= is_valid - gradients[:] = exps / exps.sum(axis=1).reshape((n, 1)) - gradients -= g_exps / g_exps.sum(axis=1).reshape((n, 1)) - - def _begin_update(self, model, states, tokvecs, drop=0.): nr_class = self.moves.n_moves attr_names = self.model.ops.allocate((2,), dtype='i') From 35458987e8d93db025bedabb7c5045ac7060f7ff Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 7 May 2017 23:05:01 +0200 Subject: [PATCH 3/9] Checkpoint -- nearly finished reimpl --- spacy/syntax/parser.pyx | 83 +++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index c7170c747..984cc1b5b 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -28,6 +28,8 @@ from murmurhash.mrmr cimport hash64 from preshed.maps cimport MapStruct from preshed.maps cimport map_get +from thinc.api import layerize + from numpy import exp from . import _parse_features @@ -55,40 +57,45 @@ def set_debug(val): def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, feat_maps, upper_model): - is_valid = model.ops.allocate((len(docs), system.n_moves), dtype='i') - costs = model.ops.allocate((len(docs), system.n_moves), dtype='f') - token_ids = model.ops.allocate((len(docs), StateClass.nr_context_tokens()), - dtype='uint64') - cached, backprops = zip(*[lyr.begin_update(tokvecs) for lyr in feat_maps) + cdef int[:, :] is_valid_ + cdef float[:, :] costs_ + cdef int[:, :] token_ids + is_valid = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='i') + costs = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='f') + token_ids = upper_model.ops.allocate((len(tokvecs), StateClass.nr_context_tokens()), + dtype='uint64') + cached, backprops = zip(*[lyr.begin_update(tokvecs) for lyr in feat_maps]) + is_valid_ = is_valid + costs_ = costs def forward(states, drop=0.): - nonlocal is_valid, costs, token_ids, features + nonlocal is_valid, costs, token_ids, moves is_valid = is_valid[:len(states)] costs = costs[:len(states)] token_ids = token_ids[:len(states)] is_valid = is_valid[:len(states)] - for state in states: - state.set_context_tokens(&token_ids[i]) - moves.set_valid(&is_valid[i], state.c) + cdef StateClass state + for i, state in enumerate(states): + state.set_context_tokens(token_ids[i]) + moves.set_valid(&is_valid_[i, 0], state.c) features = cached[token_ids].sum(axis=1) scores, bp_scores = upper_model.begin_update(features, drop=drop) - softmaxed = model.ops.softmax(scores) + softmaxed = upper_model.ops.softmax(scores) # Renormalize for invalid actions softmaxed *= is_valid softmaxed /= softmaxed.sum(axis=1).reshape((softmaxed.shape[0], 1)) def backward(golds, sgd=None): - nonlocal costs_, is_valid_, moves_ - cdef TransitionSystem moves = moves_ - cdef int[:, :] is_valid - cdef float[:, :] costs + nonlocal costs_, is_valid_, moves for i, (state, gold) in enumerate(zip(states, golds)): - moves.set_costs(&costs[i], &is_valid[i], + moves.set_costs(&is_valid_[i, 0], &costs_[i, 0], state, gold) - set_log_loss(model.ops, d_scores, - scores, is_valid, costs) + d_scores = scores.copy() + d_scores.fill(0) + set_log_loss(upper_model.ops, d_scores, + scores, is_valid_, costs_) d_tokens = bp_scores(d_scores, sgd) return d_tokens @@ -119,6 +126,17 @@ def transition_batch(TransitionSystem moves, states, scores): action.do(state.c, action.label) +def init_states(TransitionSystem moves, docs): + states = [] + cdef Doc doc + cdef StateClass state + for i, doc in enumerate(docs): + state = StateClass.init(doc.c, doc.length) + moves.initialize_state(state.c) + states.append(state) + return states + + cdef class Parser: """ Base class of the DependencyParser and EntityRecognizer. @@ -176,7 +194,8 @@ cdef class Parser: def build_model(self, width=32, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR) self.model = build_model(width*2, 2, self.moves.n_moves) - self.feature_maps = build_feature_maps(nr_context_tokens, width, nr_vector)) + # TODO + self.feature_maps = [] #build_feature_maps(nr_context_tokens, width, nr_vector) def __call__(self, Doc tokens): """ @@ -248,6 +267,7 @@ cdef class Parser: model = get_greedy_model_for_batch([d.tensor for d in docs], self.moves, self.model, self.feat_maps) + states = init_states(self.moves, docs) d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs] output = list(d_tokens) @@ -261,7 +281,7 @@ cdef class Parser: transition_batch(self.moves, states) # Get unfinished states (and their matching gold and token gradients) todo = filter(lambda sp: not sp[0].py_is_final(), todo) - return output, sum(losses) + return output def begin_training(self, docs, golds): for gold in golds: @@ -336,31 +356,6 @@ def _begin_update(self, model, states, tokvecs, drop=0.): return finish_update(d_scores, sgd=sgd) return softmaxed, backward -def _init_states(self, docs): - states = [] - cdef Doc doc - cdef StateClass state - for i, doc in enumerate(docs): - state = StateClass.init(doc.c, doc.length) - self.moves.initialize_state(state.c) - states.append(state) - return states - -def _validate_batch(self, int[:, ::1] is_valid, states): - cdef StateClass state - cdef int i - for i, state in enumerate(states): - self.moves.set_valid(&is_valid[i, 0], state.c) - -def _cost_batch(self, weight_t[:, ::1] costs, int[:, ::1] is_valid, - states, golds): - cdef int i - cdef StateClass state - cdef GoldParse gold - for i, (state, gold) in enumerate(zip(states, golds)): - self.moves.set_costs(&is_valid[i, 0], &costs[i, 0], state, gold) - - def _get_features(self, states, all_tokvecs, attr_names, nF=1, nB=0, nS=2, nL=2, nR=2): From 10682d35abfcaba4619a1a381c881b1a08891946 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 8 May 2017 00:38:35 +0200 Subject: [PATCH 4/9] Get pre-computed version working --- bin/parser/train_ud.py | 9 ++-- spacy/_ml.py | 85 +++++++++++++++++---------------- spacy/syntax/parser.pxd | 1 + spacy/syntax/parser.pyx | 102 ++++++++++++++++++++++------------------ 4 files changed, 106 insertions(+), 91 deletions(-) diff --git a/bin/parser/train_ud.py b/bin/parser/train_ud.py index 793f1c821..fb8a746d2 100644 --- a/bin/parser/train_ud.py +++ b/bin/parser/train_ud.py @@ -144,7 +144,6 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): docs = list(Xs) for doc in docs: encoder(doc) - parser.begin_training(docs, ys) nn_loss = [0.] def track_progress(): scorer = score_model(vocab, encoder, tagger, parser, dev_Xs, dev_ys) @@ -153,7 +152,7 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): nn_loss.append(0.) trainer.each_epoch.append(track_progress) trainer.batch_size = 12 - trainer.nb_epoch = 2 + trainer.nb_epoch = 20 for docs, golds in trainer.iterate(Xs, ys, progress_bar=False): docs = [Doc(vocab, words=[w.text for w in doc]) for doc in docs] tokvecs, upd_tokvecs = encoder.begin_update(docs) @@ -161,9 +160,9 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): doc.tensor = tokvec for doc, gold in zip(docs, golds): tagger.update(doc, gold) - d_tokvecs, loss = parser.update(docs, golds, sgd=optimizer) + d_tokvecs = parser.update(docs, golds, sgd=optimizer) upd_tokvecs(d_tokvecs, sgd=optimizer) - nn_loss[-1] += loss + #nn_loss[-1] += loss nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser) #nlp.end_training(model_dir) #scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) @@ -173,7 +172,7 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): if __name__ == '__main__': import cProfile import pstats - if 0: + if 1: plac.call(main) else: cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof") diff --git a/spacy/_ml.py b/spacy/_ml.py index be82c6b2c..0aee30df6 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -51,47 +51,6 @@ def doc2feats(cols): model = layerize(forward) return model - -def build_feature_precomputer(model, feat_maps): - '''Allow a model to be "primed" by pre-computing input features in bulk. - - This is used for the parser, where we want to take a batch of documents, - and compute vectors for each (token, position) pair. These vectors can then - be reused, especially for beam-search. - - Let's say we're using 12 features for each state, e.g. word at start of - buffer, three words on stack, their children, etc. In the normal arc-eager - system, a document of length N is processed in 2*N states. This means we'll - create 2*N*12 feature vectors --- but if we pre-compute, we only need - N*12 vector computations. The saving for beam-search is much better: - if we have a beam of k, we'll normally make 2*N*12*K computations -- - so we can save the factor k. This also gives a nice CPU/GPU division: - we can do all our hard maths up front, packed into large multiplications, - and do the hard-to-program parsing on the CPU. - ''' - def precompute(input_vectors): - cached, backprops = zip(*[lyr.begin_update(input_vectors) - for lyr in feat_maps) - def forward(batch_token_ids, drop=0.): - output = ops.allocate((batch_size, output_width)) - # i: batch index - # j: position index (i.e. N0, S0, etc - # tok_i: Index of the token within its document - for i, token_ids in enumerate(batch_token_ids): - for j, tok_i in enumerate(token_ids): - output[i] += cached[j][tok_i] - def backward(d_vector, sgd=None): - d_inputs = ops.allocate((batch_size, n_feat, vec_width)) - for i, token_ids in enumerate(batch_token_ids): - for j in range(len(token_ids)): - d_inputs[i][j] = backprops[j](d_vector, sgd) - # Return the IDs, so caller can associate to correct token - return (batch_token_ids, d_inputs) - return vector, backward - return chain(layerize(forward), model) - return precompute - - def print_shape(prefix): def forward(X, drop=0.): return X, lambda dX, **kwargs: dX @@ -114,3 +73,47 @@ def flatten(seqs, drop=0.): return d_X X = ops.xp.concatenate([ops.asarray(seq) for seq in seqs]) return X, finish_update + + + +#def build_feature_precomputer(model, feat_maps): +# '''Allow a model to be "primed" by pre-computing input features in bulk. +# +# This is used for the parser, where we want to take a batch of documents, +# and compute vectors for each (token, position) pair. These vectors can then +# be reused, especially for beam-search. +# +# Let's say we're using 12 features for each state, e.g. word at start of +# buffer, three words on stack, their children, etc. In the normal arc-eager +# system, a document of length N is processed in 2*N states. This means we'll +# create 2*N*12 feature vectors --- but if we pre-compute, we only need +# N*12 vector computations. The saving for beam-search is much better: +# if we have a beam of k, we'll normally make 2*N*12*K computations -- +# so we can save the factor k. This also gives a nice CPU/GPU division: +# we can do all our hard maths up front, packed into large multiplications, +# and do the hard-to-program parsing on the CPU. +# ''' +# def precompute(input_vectors): +# cached, backprops = zip(*[lyr.begin_update(input_vectors) +# for lyr in feat_maps) +# def forward(batch_token_ids, drop=0.): +# output = ops.allocate((batch_size, output_width)) +# # i: batch index +# # j: position index (i.e. N0, S0, etc +# # tok_i: Index of the token within its document +# for i, token_ids in enumerate(batch_token_ids): +# for j, tok_i in enumerate(token_ids): +# output[i] += cached[j][tok_i] +# def backward(d_vector, sgd=None): +# d_inputs = ops.allocate((batch_size, n_feat, vec_width)) +# for i, token_ids in enumerate(batch_token_ids): +# for j in range(len(token_ids)): +# d_inputs[i][j] = backprops[j](d_vector, sgd) +# # Return the IDs, so caller can associate to correct token +# return (batch_token_ids, d_inputs) +# return vector, backward +# return chain(layerize(forward), model) +# return precompute +# +# + diff --git a/spacy/syntax/parser.pxd b/spacy/syntax/parser.pxd index 0b3279a1b..42088a9ff 100644 --- a/spacy/syntax/parser.pxd +++ b/spacy/syntax/parser.pxd @@ -13,5 +13,6 @@ cdef class Parser: cdef readonly object model cdef readonly TransitionSystem moves cdef readonly object cfg + cdef public object feature_maps #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 984cc1b5b..84a927b94 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -28,10 +28,11 @@ from murmurhash.mrmr cimport hash64 from preshed.maps cimport MapStruct from preshed.maps cimport map_get -from thinc.api import layerize -from numpy import exp +from thinc.api import layerize, chain +from thinc.neural import Model, Maxout +from .._ml import get_col from . import _parse_features from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport fill_context @@ -46,8 +47,9 @@ from ..strings cimport StringStore from ..gold cimport GoldParse from ..attrs cimport TAG, DEP -from .._ml import build_state2vec, build_model, precompute_hiddens +def get_templates(*args, **kwargs): + return [] USE_FTRL = True DEBUG = False @@ -56,30 +58,39 @@ def set_debug(val): DEBUG = val -def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, feat_maps, upper_model): +def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, feat_maps): cdef int[:, :] is_valid_ cdef float[:, :] costs_ - cdef int[:, :] token_ids + lengths = [len(t) for t in tokvecs] + tokvecs = upper_model.ops.flatten(tokvecs) is_valid = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='i') costs = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='f') - token_ids = upper_model.ops.allocate((len(tokvecs), StateClass.nr_context_tokens()), - dtype='uint64') + token_ids = upper_model.ops.allocate((len(tokvecs), len(feat_maps)), dtype='i') cached, backprops = zip(*[lyr.begin_update(tokvecs) for lyr in feat_maps]) is_valid_ = is_valid costs_ = costs - def forward(states, drop=0.): + def forward(states_offsets, drop=0.): nonlocal is_valid, costs, token_ids, moves + states, offsets = states_offsets is_valid = is_valid[:len(states)] costs = costs[:len(states)] token_ids = token_ids[:len(states)] is_valid = is_valid[:len(states)] cdef StateClass state - for i, state in enumerate(states): + cdef int i + for i, (offset, state) in enumerate(zip(offsets, states)): state.set_context_tokens(token_ids[i]) moves.set_valid(&is_valid_[i, 0], state.c) - - features = cached[token_ids].sum(axis=1) + adjusted_ids = token_ids.copy() + for i, offset in enumerate(offsets): + adjusted_ids[i] *= token_ids[i] >= 0 + adjusted_ids[i] += offset + features = upper_model.ops.allocate((len(states), 64), dtype='f') + for i in range(len(states)): + for j, tok_i in enumerate(adjusted_ids[i]): + if tok_i >= 0: + features[i] += cached[j][tok_i] scores, bp_scores = upper_model.begin_update(features, drop=drop) softmaxed = upper_model.ops.softmax(scores) @@ -89,15 +100,16 @@ def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, feat_maps, upper def backward(golds, sgd=None): nonlocal costs_, is_valid_, moves + cdef int i for i, (state, gold) in enumerate(zip(states, golds)): moves.set_costs(&is_valid_[i, 0], &costs_[i, 0], state, gold) d_scores = scores.copy() d_scores.fill(0) set_log_loss(upper_model.ops, d_scores, - scores, is_valid_, costs_) + scores, is_valid, costs) d_tokens = bp_scores(d_scores, sgd) - return d_tokens + return (token_ids, d_tokens) return softmaxed, backward @@ -127,14 +139,18 @@ def transition_batch(TransitionSystem moves, states, scores): def init_states(TransitionSystem moves, docs): - states = [] cdef Doc doc cdef StateClass state + offsets = [] + states = [] + offset = 0 for i, doc in enumerate(docs): state = StateClass.init(doc.c, doc.length) moves.initialize_state(state.c) states.append(state) - return states + offsets.append(offset) + offset += len(doc) + return states, offsets cdef class Parser: @@ -184,18 +200,22 @@ cdef class Parser: cfg['actions'] = TransitionSystem.get_actions(**cfg) self.moves = TransitionSystem(vocab.strings, cfg['actions']) if model is None: - model = self.build_model(**cfg) - self.model = model + self.model, self.feature_maps = self.build_model(**cfg) + else: + self.model, self.feature_maps = model self.cfg = cfg def __reduce__(self): return (Parser, (self.vocab, self.moves, self.model), None, None) - def build_model(self, width=32, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): + def build_model(self, width=64, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR) - self.model = build_model(width*2, 2, self.moves.n_moves) + + model = chain(Maxout(width, width), Maxout(self.moves.n_moves, width)) # TODO - self.feature_maps = [] #build_feature_maps(nr_context_tokens, width, nr_vector) + feature_maps = [Maxout(width, width) + for i in range(nr_context_tokens)] + return model, feature_maps def __call__(self, Doc tokens): """ @@ -245,19 +265,21 @@ cdef class Parser: cdef Doc doc cdef StateClass state model = get_greedy_model_for_batch([d.tensor for d in docs], - self.moves, self.model, self.feat_maps) - states = [StateClass.init(doc.c, doc.length) for doc in docs] - todo = list(states) + self.moves, self.model, self.feature_maps) + states, offsets = init_states(self.moves, docs) + all_states = list(states) + todo = list(zip(states, offsets)) while todo: - scores = model(todo) - transition_batch(self.moves, todo, scores) - todo = [st for st in states if not st.is_final()] - for state, doc in zip(states, docs): + states, offsets = zip(*todo) + scores = model((states, offsets)) + transition_batch(self.moves, states, scores) + todo = [st for st in todo if not st[0].py_is_final()] + for state, doc in zip(all_states, docs): self.moves.finalize_state(state.c) for i in range(doc.length): doc.c[i] = state.c._sent[i] for doc in docs: - self.moves.finalize_parse(doc) + self.moves.finalize_doc(doc) def update(self, docs, golds, drop=0., sgd=None): if isinstance(docs, Doc) and isinstance(golds, GoldParse): @@ -266,33 +288,23 @@ cdef class Parser: self.moves.preprocess_gold(gold) model = get_greedy_model_for_batch([d.tensor for d in docs], - self.moves, self.model, self.feat_maps) - states = init_states(self.moves, docs) + self.moves, self.model, self.feature_maps) + states, offsets = init_states(self.moves, docs) d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs] output = list(d_tokens) - todo = zip(states, golds, d_tokens) + todo = zip(states, offsets, golds, d_tokens) while todo: - states, golds, d_tokens = zip(*todo) - scores, finish_update = model.begin_update(token_ids) - d_state_features = finish_update(golds, sgd=sgd) + states, offsets, golds, d_tokens = zip(*todo) + scores, finish_update = model.begin_update((states, offsets)) + (token_ids, d_state_features) = finish_update(golds, sgd=sgd) for i, token_ids in enumerate(token_ids): d_tokens[i][token_ids] += d_state_features[i] - transition_batch(self.moves, states) + transition_batch(self.moves, states, scores) # Get unfinished states (and their matching gold and token gradients) todo = filter(lambda sp: not sp[0].py_is_final(), todo) return output - def begin_training(self, docs, golds): - for gold in golds: - self.moves.preprocess_gold(gold) - states = self._init_states(docs) - tokvecs = [d.tensor for d in docs] - - features = self._get_features(states, tokvecs) - self.model.begin_training(features) - - def step_through(self, Doc doc, GoldParse gold=None): """ Set up a stepwise state, to introspect and control the transition sequence. From 2e2268a442e4a060f0d8d0d1bd68fd45566fcc53 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 8 May 2017 11:36:37 +0200 Subject: [PATCH 5/9] Precomputable hidden now working --- bin/parser/train_ud.py | 4 +-- spacy/_ml.py | 54 ++++++++++++++++++++++++++++++++++++++++- spacy/syntax/parser.pyx | 30 ++++++++++++++--------- 3 files changed, 74 insertions(+), 14 deletions(-) diff --git a/bin/parser/train_ud.py b/bin/parser/train_ud.py index fb8a746d2..7b8aedfd3 100644 --- a/bin/parser/train_ud.py +++ b/bin/parser/train_ud.py @@ -138,8 +138,8 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): Xs, ys = organize_data(vocab, train_sents) dev_Xs, dev_ys = organize_data(vocab, dev_sents) - Xs = Xs[:100] - ys = ys[:100] + Xs = Xs[:1000] + ys = ys[:1000] with encoder.model.begin_training(Xs[:100], ys[:100]) as (trainer, optimizer): docs = list(Xs) for doc in docs: diff --git a/spacy/_ml.py b/spacy/_ml.py index 0aee30df6..e62971a69 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -5,9 +5,61 @@ from thinc.neural._classes.hash_embed import HashEmbed from thinc.neural._classes.convolution import ExtractWindow from thinc.neural._classes.static_vectors import StaticVectors from thinc.neural._classes.batchnorm import BatchNorm - +from thinc import describe +from thinc.describe import Dimension, Synapses, Biases, Gradient +from thinc.neural._classes.affine import _set_dimensions_if_needed from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP +import numpy + + +@describe.on_data(_set_dimensions_if_needed) +@describe.attributes( + nI=Dimension("Input size"), + nF=Dimension("Number of features"), + nO=Dimension("Output size"), + W=Synapses("Weights matrix", + lambda obj: (obj.nO, obj.nF, obj.nI), + lambda W, ops: ops.xavier_uniform_init(W)), + b=Biases("Bias vector", + lambda obj: (obj.nO,)), + d_W=Gradient("W"), + d_b=Gradient("b") +) +class PrecomputableAffine(Model): + def __init__(self, nO=None, nI=None, nF=None, **kwargs): + Model.__init__(self, **kwargs) + self.nO = nO + self.nI = nI + self.nF = nF + + def begin_update(self, X, drop=0.): + # X: (b, i) + # Xf: (b, f, i) + # dY: (b, o) + # dYf: (b, f, o) + #Yf = numpy.einsum('bi,ofi->bfo', X, self.W) + Yf = self.ops.xp.tensordot( + X, self.W, axes=[[1], [2]]).transpose((0, 2, 1)) + Yf += self.b + def backward(dY_ids, sgd=None): + dY, ids = dY_ids + Xf = X[ids] + + #dW = numpy.einsum('bo,bfi->ofi', dY, Xf) + dW = self.ops.xp.tensordot(dY, Xf, axes=[[0], [0]]) + db = dY.sum(axis=0) + #dXf = numpy.einsum('bo,ofi->bfi', dY, self.W) + dXf = self.ops.xp.tensordot(dY, self.W, axes=[[1], [0]]) + + self.d_W += dW + self.d_b += db + + if sgd is not None: + sgd(self._mem.weights, self._mem.gradient, key=self.id) + return dXf + return Yf, backward + def get_col(idx): def forward(X, drop=0.): diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 84a927b94..d597219e2 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -32,7 +32,7 @@ from preshed.maps cimport map_get from thinc.api import layerize, chain from thinc.neural import Model, Maxout -from .._ml import get_col +from .._ml import PrecomputableAffine from . import _parse_features from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport fill_context @@ -58,21 +58,24 @@ def set_debug(val): DEBUG = val -def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, feat_maps): +def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, lower_model): cdef int[:, :] is_valid_ cdef float[:, :] costs_ lengths = [len(t) for t in tokvecs] tokvecs = upper_model.ops.flatten(tokvecs) is_valid = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='i') costs = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='f') - token_ids = upper_model.ops.allocate((len(tokvecs), len(feat_maps)), dtype='i') - cached, backprops = zip(*[lyr.begin_update(tokvecs) for lyr in feat_maps]) + token_ids = upper_model.ops.allocate((len(tokvecs), lower_model.nF), dtype='i') + + cached, bp_features = lower_model.begin_update(tokvecs, drop=0.) + is_valid_ = is_valid costs_ = costs def forward(states_offsets, drop=0.): nonlocal is_valid, costs, token_ids, moves states, offsets = states_offsets + assert len(states) != 0 is_valid = is_valid[:len(states)] costs = costs[:len(states)] token_ids = token_ids[:len(states)] @@ -90,12 +93,17 @@ def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, fea for i in range(len(states)): for j, tok_i in enumerate(adjusted_ids[i]): if tok_i >= 0: - features[i] += cached[j][tok_i] + features[i] += cached[tok_i, j] scores, bp_scores = upper_model.begin_update(features, drop=drop) + scores = upper_model.ops.relu(scores) softmaxed = upper_model.ops.softmax(scores) # Renormalize for invalid actions softmaxed *= is_valid + totals = softmaxed.sum(axis=1) + for total in totals: + assert total > 0, (totals, scores, softmaxed) + assert total <= 1.1, totals softmaxed /= softmaxed.sum(axis=1).reshape((softmaxed.shape[0], 1)) def backward(golds, sgd=None): @@ -108,7 +116,9 @@ def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, fea d_scores.fill(0) set_log_loss(upper_model.ops, d_scores, scores, is_valid, costs) - d_tokens = bp_scores(d_scores, sgd) + upper_model.ops.backprop_relu(d_scores, scores, inplace=True) + d_features = bp_scores(d_scores, sgd) + d_tokens = bp_features((d_features, adjusted_ids), sgd) return (token_ids, d_tokens) return softmaxed, backward @@ -211,11 +221,9 @@ cdef class Parser: def build_model(self, width=64, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR) - model = chain(Maxout(width, width), Maxout(self.moves.n_moves, width)) - # TODO - feature_maps = [Maxout(width, width) - for i in range(nr_context_tokens)] - return model, feature_maps + upper = chain(Maxout(width, width), Maxout(self.moves.n_moves, width)) + lower = PrecomputableAffine(width, nF=nr_context_tokens, nI=width) + return upper, lower def __call__(self, Doc tokens): """ From 807cb2e3703806c96c8e4c47db2fc0ef5380d39b Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 8 May 2017 14:24:43 +0200 Subject: [PATCH 6/9] Add PretrainableMaxouts --- spacy/_ml.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/spacy/_ml.py b/spacy/_ml.py index e62971a69..d3bb903e7 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -61,6 +61,64 @@ class PrecomputableAffine(Model): return Yf, backward +@describe.on_data(_set_dimensions_if_needed) +@describe.attributes( + nI=Dimension("Input size"), + nF=Dimension("Number of features"), + nP=Dimension("Number of pieces"), + nO=Dimension("Output size"), + W=Synapses("Weights matrix", + lambda obj: (obj.nF, obj.nO, obj.nP, obj.nI), + lambda W, ops: ops.xavier_uniform_init(W)), + b=Biases("Bias vector", + lambda obj: (obj.nO, obj.nP)), + d_W=Gradient("W"), + d_b=Gradient("b") +) +class PrecomputableMaxouts(Model): + def __init__(self, nO=None, nI=None, nF=None, pieces=2, **kwargs): + Model.__init__(self, **kwargs) + self.nO = nO + self.nP = pieces + self.nI = nI + self.nF = nF + + def begin_update(self, X, drop=0.): + # X: (b, i) + # Yfp: (f, b, o, p) + # Yf: (f, b, o) + # Xf: (b, f, i) + # dY: (b, o) + # dYp: (b, o, p) + # W: (f, o, p, i) + # b: (o, p) + + Yfp = numpy.einsum('bi,fopi->fbop', X, self.W) + Yfp += self.b + Yf = self.ops.allocate((self.nF, X.shape[0], self.nO)) + which = self.ops.allocate((self.nF, X.shape[0], self.nO), dtype='i') + for i in range(self.nF): + Yf[i], which[i] = self.ops.maxout(Yfp[i]) + def backward(dY_ids, sgd=None): + dY, ids = dY_ids + Xf = X[ids] + dYp = self.ops.allocate((dY.shape[0], self.nO, self.nP)) + for i in range(self.nF): + dYp += self.ops.backprop_maxout(dY, which[i], self.nP) + + dXf = numpy.einsum('bop,fopi->bfi', dYp, self.W) + dW = numpy.einsum('bop,bfi->fopi', dYp, Xf) + db = dYp.sum(axis=0) + + self.d_W += dW + self.d_b += db + + if sgd is not None: + sgd(self._mem.weights, self._mem.gradient, key=self.id) + return dXf + return Yf, backward + + def get_col(idx): def forward(X, drop=0.): assert len(X.shape) <= 3 From 8d2eab74da43d9dcbf1a038a79d8bcd36d45c5e4 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 8 May 2017 14:24:55 +0200 Subject: [PATCH 7/9] Use PretrainableMaxouts --- spacy/syntax/parser.pyx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index d597219e2..4989a1fb3 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -32,7 +32,7 @@ from preshed.maps cimport map_get from thinc.api import layerize, chain from thinc.neural import Model, Maxout -from .._ml import PrecomputableAffine +from .._ml import PrecomputableAffine, PrecomputableMaxouts from . import _parse_features from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport fill_context @@ -93,7 +93,7 @@ def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, low for i in range(len(states)): for j, tok_i in enumerate(adjusted_ids[i]): if tok_i >= 0: - features[i] += cached[tok_i, j] + features[i] += cached[j, tok_i] scores, bp_scores = upper_model.begin_update(features, drop=drop) scores = upper_model.ops.relu(scores) @@ -222,7 +222,7 @@ cdef class Parser: nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR) upper = chain(Maxout(width, width), Maxout(self.moves.n_moves, width)) - lower = PrecomputableAffine(width, nF=nr_context_tokens, nI=width) + lower = PrecomputableMaxouts(width, nF=nr_context_tokens, nI=width) return upper, lower def __call__(self, Doc tokens): From a66a4a4d0fb7c99bf5eb1ba81cef8ae1564c0c45 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 8 May 2017 14:46:50 +0200 Subject: [PATCH 8/9] Replace einsums --- spacy/_ml.py | 12 +++++++++--- spacy/pipeline.pyx | 2 +- spacy/syntax/parser.pyx | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index d3bb903e7..f9894cd54 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -93,7 +93,10 @@ class PrecomputableMaxouts(Model): # W: (f, o, p, i) # b: (o, p) - Yfp = numpy.einsum('bi,fopi->fbop', X, self.W) + # Yfp = numpy.einsum('bi,fopi->fbop', X, self.W) + Yfp = self.ops.xp.tensordot(X, self.W, + axes=[[1], [3]]).transpose((1, 0, 2, 3)) + Yfp = self.ops.xp.ascontiguousarray(Yfp) Yfp += self.b Yf = self.ops.allocate((self.nF, X.shape[0], self.nO)) which = self.ops.allocate((self.nF, X.shape[0], self.nO), dtype='i') @@ -106,8 +109,11 @@ class PrecomputableMaxouts(Model): for i in range(self.nF): dYp += self.ops.backprop_maxout(dY, which[i], self.nP) - dXf = numpy.einsum('bop,fopi->bfi', dYp, self.W) - dW = numpy.einsum('bop,bfi->fopi', dYp, Xf) + #dXf = numpy.einsum('bop,fopi->bfi', dYp, self.W) + dXf = self.ops.xp.tensordot(dYp, self.W, axes=[[1,2], [1,2]]) + #dW = numpy.einsum('bfi,bop->fopi', Xf, dYp) + dW = self.ops.xp.tensordot(Xf, dYp, axes=[[0], [0]]) + dW = dW.transpose((0, 2, 3, 1)) db = dYp.sum(axis=0) self.d_W += dW diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 61c71c2bb..c357397f8 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -21,7 +21,7 @@ class TokenVectorEncoder(object): '''Assign position-sensitive vectors to tokens, using a CNN or RNN.''' def __init__(self, vocab, **cfg): self.vocab = vocab - self.model = build_tok2vec(vocab.lang, 64, **cfg) + self.model = build_tok2vec(vocab.lang, **cfg) self.tagger = chain( self.model, Softmax(self.vocab.morphology.n_tags)) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 4989a1fb3..76f16f881 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -89,7 +89,7 @@ def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, low for i, offset in enumerate(offsets): adjusted_ids[i] *= token_ids[i] >= 0 adjusted_ids[i] += offset - features = upper_model.ops.allocate((len(states), 64), dtype='f') + features = upper_model.ops.allocate((len(states), lower_model.nO), dtype='f') for i in range(len(states)): for j, tok_i in enumerate(adjusted_ids[i]): if tok_i >= 0: @@ -222,7 +222,7 @@ cdef class Parser: nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR) upper = chain(Maxout(width, width), Maxout(self.moves.n_moves, width)) - lower = PrecomputableMaxouts(width, nF=nr_context_tokens, nI=width) + lower = PrecomputableMaxouts(width, nF=nr_context_tokens, nI=width*2) return upper, lower def __call__(self, Doc tokens): From 66252f3e7124d6114908cedf2c8e7f7b75fcf0f3 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 8 May 2017 14:47:11 +0200 Subject: [PATCH 9/9] Change vector width --- bin/parser/train_ud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/parser/train_ud.py b/bin/parser/train_ud.py index 7b8aedfd3..af79aa927 100644 --- a/bin/parser/train_ud.py +++ b/bin/parser/train_ud.py @@ -133,7 +133,7 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): for tag in tags: assert tag in vocab.morphology.tag_map, repr(tag) tagger = Tagger(vocab) - encoder = TokenVectorEncoder(vocab) + encoder = TokenVectorEncoder(vocab, width=128) parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0) Xs, ys = organize_data(vocab, train_sents)