diff --git a/spacy/_ml.py b/spacy/_ml.py index 4a41339aa..898d6ab49 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -32,7 +32,7 @@ import io # TODO: Unset this once we don't want to support models previous models. import thinc.neural._classes.layernorm -thinc.neural._classes.layernorm.set_compat_six_eight(True) +thinc.neural._classes.layernorm.set_compat_six_eight(False) VECTORS_KEY = 'spacy_pretrained_vectors' @@ -213,6 +213,72 @@ class PrecomputableMaxouts(Model): return dXf return Yfp, backward +# Thinc's Embed class is a bit broken atm, so drop this here. +from thinc import describe +from thinc.neural._classes.embed import _uniform_init + + +@describe.attributes( + nV=describe.Dimension("Number of vectors"), + nO=describe.Dimension("Size of output"), + vectors=describe.Weights("Embedding table", + lambda obj: (obj.nV, obj.nO), + _uniform_init(-0.1, 0.1) + ), + d_vectors=describe.Gradient("vectors") +) +class Embed(Model): + name = 'embed' + + def __init__(self, nO, nV=None, **kwargs): + if nV is not None: + nV += 1 + Model.__init__(self, **kwargs) + if 'name' in kwargs: + self.name = kwargs['name'] + self.column = kwargs.get('column', 0) + self.nO = nO + self.nV = nV + + def predict(self, ids): + if ids.ndim == 2: + ids = ids[:, self.column] + return self.ops.xp.ascontiguousarray(self.vectors[ids], dtype='f') + + def begin_update(self, ids, drop=0.): + if ids.ndim == 2: + ids = ids[:, self.column] + vectors = self.ops.xp.ascontiguousarray(self.vectors[ids], dtype='f') + def backprop_embed(d_vectors, sgd=None): + n_vectors = d_vectors.shape[0] + self.ops.scatter_add(self.d_vectors, ids, d_vectors) + if sgd is not None: + sgd(self._mem.weights, self._mem.gradient, key=self.id) + return None + return vectors, backprop_embed + + +def HistoryFeatures(nr_class, hist_size=8, nr_dim=8): + '''Wrap a model, adding features representing action history.''' + if hist_size == 0: + return layerize(noop()) + embed_tables = [Embed(nr_dim, nr_class, column=i, name='embed%d') + for i in range(hist_size)] + embed = concatenate(*embed_tables) + ops = embed.ops + def add_history_fwd(vectors_hists, drop=0.): + vectors, hist_ids = vectors_hists + hist_feats, bp_hists = embed.begin_update(hist_ids, drop=drop) + outputs = ops.xp.hstack((vectors, hist_feats)) + + def add_history_bwd(d_outputs, sgd=None): + d_vectors = d_outputs[:, :vectors.shape[1]] + d_hists = d_outputs[:, vectors.shape[1]:] + bp_hists(d_hists, sgd=sgd) + return embed.ops.xp.ascontiguousarray(d_vectors) + return outputs, add_history_bwd + return wrap(add_history_fwd, embed) + def drop_layer(layer, factor=2.): def drop_layer_fwd(X, drop=0.): diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py index 42e077dc2..29e30b7d2 100644 --- a/spacy/cli/evaluate.py +++ b/spacy/cli/evaluate.py @@ -42,7 +42,8 @@ def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False, Evaluate a model. To render a sample of parses in a HTML file, set an output directory as the displacy_path argument. """ - util.use_gpu(gpu_id) + if gpu_id >= 0: + util.use_gpu(gpu_id) util.set_env_log(False) data_path = util.ensure_path(data_path) displacy_path = util.ensure_path(displacy_path) diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx index a26900f6b..da4efefbc 100644 --- a/spacy/syntax/_beam_utils.pyx +++ b/spacy/syntax/_beam_utils.pyx @@ -21,6 +21,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) moves = _moves dest.clone(src) moves[clas].do(dest.c, moves[clas].label) + dest.c.push_hist(clas) cdef int _check_final_state(void* _state, void* extra_args) except -1: @@ -148,8 +149,8 @@ def get_token_ids(states, int n_tokens): nr_update = 0 def update_beam(TransitionSystem moves, int nr_feature, int max_steps, states, golds, - state2vec, vec2scores, - int width, float density, + state2vec, vec2scores, + int width, float density, int hist_feats, losses=None, drop=0.): global nr_update cdef MaxViolation violn @@ -180,7 +181,11 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps, # Now that we have our flat list of states, feed them through the model token_ids = get_token_ids(states, nr_feature) vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop) - scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) + if hist_feats: + hists = numpy.asarray([st.history[:hist_feats] for st in states], dtype='i') + scores, bp_scores = vec2scores.begin_update((vectors, hists), drop=drop) + else: + scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) # Store the callbacks for the backward pass backprops.append((token_ids, bp_vectors, bp_scores)) diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 4fb16881a..1864b22b3 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -1,4 +1,4 @@ -from libc.string cimport memcpy, memset +from libc.string cimport memcpy, memset, memmove from libc.stdlib cimport malloc, calloc, free from libc.stdint cimport uint32_t, uint64_t @@ -15,6 +15,23 @@ from ..typedefs cimport attr_t cdef inline bint is_space_token(const TokenC* token) nogil: return Lexeme.c_check_flag(token.lex, IS_SPACE) +cdef struct RingBufferC: + int[8] data + int i + int default + +cdef inline int ring_push(RingBufferC* ring, int value) nogil: + ring.data[ring.i] = value + ring.i += 1 + if ring.i >= 8: + ring.i = 0 + +cdef inline int ring_get(RingBufferC* ring, int i) nogil: + if i >= ring.i: + return ring.default + else: + return ring.data[ring.i-i] + cdef cppclass StateC: int* _stack @@ -23,6 +40,7 @@ cdef cppclass StateC: TokenC* _sent Entity* _ents TokenC _empty_token + RingBufferC _hist int length int offset int _s_i @@ -37,6 +55,7 @@ cdef cppclass StateC: this.shifted = calloc(length + (PADDING * 2), sizeof(bint)) this._sent = calloc(length + (PADDING * 2), sizeof(TokenC)) this._ents = calloc(length + (PADDING * 2), sizeof(Entity)) + memset(&this._hist, 0, sizeof(this._hist)) this.offset = 0 cdef int i for i in range(length + (PADDING * 2)): @@ -74,6 +93,9 @@ cdef cppclass StateC: free(this.shifted - PADDING) void set_context_tokens(int* ids, int n) nogil: + if n == 2: + ids[0] = this.B(0) + ids[1] = this.S(0) if n == 8: ids[0] = this.B(0) ids[1] = this.B(1) @@ -271,7 +293,14 @@ cdef cppclass StateC: sig[8] = this.B_(0)[0] sig[9] = this.E_(0)[0] sig[10] = this.E_(1)[0] - return hash64(sig, sizeof(sig), this._s_i) + return hash64(sig, sizeof(sig), this._s_i) \ + + hash64(&this._hist, sizeof(RingBufferC), 1) + + void push_hist(int act) nogil: + ring_push(&this._hist, act+1) + + int get_hist(int i) nogil: + return ring_get(&this._hist, i) void push() nogil: if this.B(0) != -1: diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 459c94463..fdcf1d2d1 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -50,6 +50,7 @@ from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune from .._ml import Residual, drop_layer, flatten from .._ml import link_vectors_to_models +from .._ml import HistoryFeatures from ..compat import json_dumps from . import _parse_features @@ -67,12 +68,10 @@ from ..gold cimport GoldParse from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG from . import _beam_utils -USE_FINE_TUNE = True def get_templates(*args, **kwargs): return [] -USE_FTRL = True DEBUG = False def set_debug(val): global DEBUG @@ -239,12 +238,17 @@ cdef class Parser: Base class of the DependencyParser and EntityRecognizer. """ @classmethod - def Model(cls, nr_class, token_vector_width=128, hidden_width=200, depth=1, **cfg): - depth = util.env_opt('parser_hidden_depth', depth) - token_vector_width = util.env_opt('token_vector_width', token_vector_width) - hidden_width = util.env_opt('hidden_width', hidden_width) - parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2) - embed_size = util.env_opt('embed_size', 7000) + def Model(cls, nr_class, **cfg): + depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 2)) + token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128)) + hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 128)) + parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 1)) + embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000)) + hist_size = util.env_opt('history_feats', cfg.get('hist_size', 4)) + hist_width = util.env_opt('history_width', cfg.get('hist_width', 16)) + if hist_size >= 1 and depth == 0: + raise ValueError("Inconsistent hyper-params: " + "history_feats >= 1 but parser_hidden_depth==0") tok2vec = Tok2Vec(token_vector_width, embed_size, pretrained_dims=cfg.get('pretrained_dims', 0)) tok2vec = chain(tok2vec, flatten) @@ -262,22 +266,40 @@ cdef class Parser: if depth == 0: upper = chain() upper.is_noop = True - else: + elif hist_size and depth == 1: upper = chain( - clone(Maxout(hidden_width), depth-1), + HistoryFeatures(nr_class=nr_class, hist_size=hist_size, + nr_dim=hist_width), + zero_init(Affine(nr_class, hidden_width+hist_size*hist_width, + drop_factor=0.0))) + upper.is_noop = False + elif hist_size: + upper = chain( + HistoryFeatures(nr_class=nr_class, hist_size=hist_size, + nr_dim=hist_width), + LayerNorm(Maxout(hidden_width, hidden_width+hist_size*hist_width)), + clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-2), zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) ) upper.is_noop = False + else: + upper = chain( + clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-1), + zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) + ) + upper.is_noop = False + # TODO: This is an unfortunate hack atm! # Used to set input dimensions in network. lower.begin_training(lower.ops.allocate((500, token_vector_width))) - upper.begin_training(upper.ops.allocate((500, hidden_width))) cfg = { 'nr_class': nr_class, - 'depth': depth, + 'hidden_depth': depth, 'token_vector_width': token_vector_width, 'hidden_width': hidden_width, - 'maxout_pieces': parser_maxout_pieces + 'maxout_pieces': parser_maxout_pieces, + 'hist_size': hist_size, + 'hist_width': hist_width } return (tok2vec, lower, upper), cfg @@ -350,7 +372,7 @@ cdef class Parser: _cleanup(beam) return output - def pipe(self, docs, int batch_size=1000, int n_threads=2, + def pipe(self, docs, int batch_size=256, int n_threads=2, beam_width=None, beam_density=None): """ Process a stream of documents. @@ -427,12 +449,18 @@ cdef class Parser: self._parse_step(next_step[i], feat_weights, nr_class, nr_feat, nr_piece) else: + hists = [] for i in range(nr_step): st = next_step[i] st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat) self.moves.set_valid(&c_is_valid[i*nr_class], st) + hists.append([st.get_hist(j+1) for j in range(8)]) + hists = numpy.asarray(hists) vectors = state2vec(token_ids[:next_step.size()]) - scores = vec2scores(vectors) + if self.cfg.get('hist_size'): + scores = vec2scores((vectors, hists)) + else: + scores = vec2scores(vectors) c_scores = scores.data for i in range(nr_step): st = next_step[i] @@ -440,6 +468,7 @@ cdef class Parser: &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class) action = self.moves.c[guess] action.do(st, action.label) + st.push_hist(guess) this_step, next_step = next_step, this_step next_step.clear() for st in this_step: @@ -478,7 +507,12 @@ cdef class Parser: states.append(stcls) token_ids = self.get_token_ids(states) vectors = state2vec(token_ids) - scores = vec2scores(vectors) + if self.cfg.get('hist_size', 0): + hists = numpy.asarray([st.history[:self.cfg['hist_size']] + for st in states], dtype='i') + scores = vec2scores((vectors, hists)) + else: + scores = vec2scores(vectors) j = 0 c_scores = scores.data for i in range(beam.size): @@ -497,8 +531,6 @@ cdef class Parser: const float* feat_weights, int nr_class, int nr_feat, int nr_piece) nogil: '''This only works with no hidden layers -- fast but inaccurate''' - #for i in cython.parallel.prange(next_step.size(), num_threads=4, nogil=True): - # self._parse_step(next_step[i], feat_weights, nr_class, nr_feat) token_ids = calloc(nr_feat, sizeof(int)) scores = calloc(nr_class * nr_piece, sizeof(float)) is_valid = calloc(nr_class, sizeof(int)) @@ -510,6 +542,7 @@ cdef class Parser: guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece) action = self.moves.c[guess] action.do(state, action.label) + state.push_hist(guess) free(is_valid) free(scores) @@ -550,7 +583,11 @@ cdef class Parser: if drop != 0: mask = vec2scores.ops.get_dropout_mask(vector.shape, drop) vector *= mask - scores, bp_scores = vec2scores.begin_update(vector, drop=drop) + hists = numpy.asarray([st.history for st in states], dtype='i') + if self.cfg.get('hist_size', 0): + scores, bp_scores = vec2scores.begin_update((vector, hists), drop=drop) + else: + scores, bp_scores = vec2scores.begin_update(vector, drop=drop) d_scores = self.get_batch_loss(states, golds, scores) d_scores /= len(docs) @@ -569,7 +606,8 @@ cdef class Parser: else: backprops.append((token_ids, d_vector, bp_vector)) self.transition_batch(states, scores) - todo = [st for st in todo if not st[0].is_final()] + todo = [(st, gold) for (st, gold) in todo + if not st.is_final()] if losses is not None: losses[self.name] += (d_scores**2).sum() n_steps += 1 @@ -602,7 +640,7 @@ cdef class Parser: states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500, states, golds, state2vec, vec2scores, - width, density, + width, density, self.cfg.get('hist_size', 0), drop=drop, losses=losses) backprop_lower = [] cdef float batch_size = len(docs) @@ -648,6 +686,7 @@ cdef class Parser: while state.B(0) < start and not state.is_final(): action = self.moves.c[oracle_actions.pop(0)] action.do(state.c, action.label) + state.c.push_hist(action.clas) n_moves += 1 has_gold = self.moves.has_gold(gold, start=start, end=start+max_length) @@ -711,6 +750,7 @@ cdef class Parser: action = self.moves.c[guess] action.do(state.c, action.label) c_scores += scores.shape[1] + state.c.push_hist(guess) def get_batch_loss(self, states, golds, float[:, ::1] scores): cdef StateClass state @@ -934,6 +974,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) moves = _moves dest.clone(src) moves[clas].do(dest.c, moves[clas].label) + dest.c.push_hist(clas) cdef int _check_final_state(void* _state, void* extra_args) except -1: diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index 228a3ff91..ddd1f558c 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -4,6 +4,7 @@ from __future__ import unicode_literals from libc.string cimport memcpy, memset from libc.stdint cimport uint32_t, uint64_t +import numpy from ..vocab cimport EMPTY_LEXEME from ..structs cimport Entity @@ -38,6 +39,13 @@ cdef class StateClass: def token_vector_lenth(self): return self.doc.tensor.shape[1] + @property + def history(self): + hist = numpy.ndarray((8,), dtype='i') + for i in range(8): + hist[i] = self.c.get_hist(i+1) + return hist + def is_final(self): return self.c.is_final() @@ -54,27 +62,3 @@ cdef class StateClass: n0 = words[self.B(0)] n1 = words[self.B(1)] return ' '.join((third, second, top, '|', n0, n1)) - - @classmethod - def nr_context_tokens(cls): - return 13 - - def set_context_tokens(self, int[::1] output): - output[0] = self.B(0) - output[1] = self.B(1) - output[2] = self.S(0) - output[3] = self.S(1) - output[4] = self.S(2) - output[5] = self.L(self.S(0), 1) - output[6] = self.L(self.S(0), 2) - output[6] = self.R(self.S(0), 1) - output[7] = self.L(self.B(0), 1) - output[8] = self.R(self.S(0), 2) - output[9] = self.L(self.S(1), 1) - output[10] = self.L(self.S(1), 2) - output[11] = self.R(self.S(1), 1) - output[12] = self.R(self.S(1), 2) - - for i in range(13): - if output[i] != -1: - output[i] += self.c.offset diff --git a/website/api/_top-level/_cli.jade b/website/api/_top-level/_cli.jade index f59d5afdd..3a4b4702a 100644 --- a/website/api/_top-level/_cli.jade +++ b/website/api/_top-level/_cli.jade @@ -314,6 +314,16 @@ p +cell Size of the parser's and NER's hidden layers. +cell #[code 128] + +row + +cell #[code history_feats] + +cell Number of previous action ID features for parser and NER. + +cell #[code 128] + + +row + +cell #[code history_width] + +cell Number of embedding dimensions for each action ID. + +cell #[code 128] + +row +cell #[code learn_rate] +cell Learning rate.