From 6aa6a5bc25eeebf1ffea4ee97f7e26d3f09c357a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 3 Oct 2017 12:43:09 +0200 Subject: [PATCH 01/27] Add a layer type for history features --- spacy/_ml.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/spacy/_ml.py b/spacy/_ml.py index 62fc7543f..38f220cc1 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -21,6 +21,7 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed from thinc.api import FeatureExtracter, with_getitem from thinc.neural.pooling import Pooling, max_pool, mean_pool, sum_pool from thinc.neural._classes.attention import ParametricAttention +from thinc.neural._classes.embed import Embed from thinc.linear.linear import LinearModel from thinc.api import uniqued, wrap, flatten_add_lengths, noop @@ -212,6 +213,27 @@ class PrecomputableMaxouts(Model): return Yfp, backward +def HistoryFeatures(nr_class, hist_size=8, nr_dim=8): + '''Wrap a model, adding features representing action history.''' + embed = Embed(nr_dim, nr_dim, nr_class) + ops = embed.ops + def add_history_fwd(vectors_hists, drop=0.): + vectors, hist_ids = vectors_hists + flat_hists, bp_hists = embed.begin_update(hist_ids.flatten(), drop=drop) + hists = flat_hists.reshape((hist_ids.shape[0], + hist_ids.shape[1] * flat_hists.shape[1])) + outputs = ops.xp.hstack((vectors, hists)) + + def add_history_bwd(d_outputs, sgd=None): + d_vectors = d_outputs[:, :vectors.shape[1]] + d_hists = d_outputs[:, vectors.shape[1]:] + bp_hists(d_hists.reshape((d_hists.shape[0]*hist_size, + int(d_hists.shape[1]/hist_size))), sgd=sgd) + return embed.ops.xp.ascontiguousarray(d_vectors) + return outputs, add_history_bwd + return wrap(add_history_fwd, embed) + + def drop_layer(layer, factor=2.): def drop_layer_fwd(X, drop=0.): if drop <= 0.: From ee41e4fea7609119655a6ad73ead2df4b754c552 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 3 Oct 2017 12:43:48 +0200 Subject: [PATCH 02/27] Support history features in stateclass --- spacy/syntax/_state.pxd | 30 ++++++++++++++++++++++++++++-- spacy/syntax/stateclass.pyx | 8 ++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 4fb16881a..f4fa49286 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -1,4 +1,4 @@ -from libc.string cimport memcpy, memset +from libc.string cimport memcpy, memset, memmove from libc.stdlib cimport malloc, calloc, free from libc.stdint cimport uint32_t, uint64_t @@ -15,6 +15,23 @@ from ..typedefs cimport attr_t cdef inline bint is_space_token(const TokenC* token) nogil: return Lexeme.c_check_flag(token.lex, IS_SPACE) +cdef struct RingBufferC: + int[8] data + int i + int default + +cdef inline int ring_push(RingBufferC* ring, int value) nogil: + ring.data[ring.i] = value + ring.i += 1 + if ring.i >= 8: + ring.i = 0 + +cdef inline int ring_get(RingBufferC* ring, int i) nogil: + if i >= ring.i: + return ring.default + else: + return ring.data[ring.i-i] + cdef cppclass StateC: int* _stack @@ -23,6 +40,7 @@ cdef cppclass StateC: TokenC* _sent Entity* _ents TokenC _empty_token + RingBufferC _hist int length int offset int _s_i @@ -37,6 +55,7 @@ cdef cppclass StateC: this.shifted = calloc(length + (PADDING * 2), sizeof(bint)) this._sent = calloc(length + (PADDING * 2), sizeof(TokenC)) this._ents = calloc(length + (PADDING * 2), sizeof(Entity)) + memset(&this._hist, 0, sizeof(this._hist)) this.offset = 0 cdef int i for i in range(length + (PADDING * 2)): @@ -271,7 +290,14 @@ cdef cppclass StateC: sig[8] = this.B_(0)[0] sig[9] = this.E_(0)[0] sig[10] = this.E_(1)[0] - return hash64(sig, sizeof(sig), this._s_i) + return hash64(sig, sizeof(sig), this._s_i) \ + + hash64(&this._hist, sizeof(RingBufferC), 1) + + void push_hist(int act) nogil: + ring_push(&this._hist, act) + + int get_hist(int i) nogil: + return ring_get(&this._hist, i) void push() nogil: if this.B(0) != -1: diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index 228a3ff91..9c179820c 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -4,6 +4,7 @@ from __future__ import unicode_literals from libc.string cimport memcpy, memset from libc.stdint cimport uint32_t, uint64_t +import numpy from ..vocab cimport EMPTY_LEXEME from ..structs cimport Entity @@ -38,6 +39,13 @@ cdef class StateClass: def token_vector_lenth(self): return self.doc.tensor.shape[1] + @property + def history(self): + hist = numpy.ndarray((8,), dtype='i') + for i in range(8): + hist[i] = self.c.get_hist(i+1) + return hist + def is_final(self): return self.c.is_final() From b50a359e1140e968e115f09ced042fc6a02fac22 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 3 Oct 2017 12:44:01 +0200 Subject: [PATCH 03/27] Add support for history features in parsing models --- spacy/syntax/nn_parser.pyx | 51 +++++++++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 1efdc4474..2277e568e 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -51,6 +51,7 @@ from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune from .._ml import Residual, drop_layer, flatten from .._ml import link_vectors_to_models +from .._ml import HistoryFeatures from ..compat import json_dumps from . import _parse_features @@ -68,7 +69,7 @@ from ..gold cimport GoldParse from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG from . import _beam_utils -USE_FINE_TUNE = True +USE_HISTORY = True def get_templates(*args, **kwargs): return [] @@ -261,18 +262,35 @@ cdef class Parser: with Model.use_device('cpu'): if depth == 0: - upper = chain() - upper.is_noop = True - else: + hist_size = 8 + nr_dim = 8 + if USE_HISTORY: + upper = chain( + HistoryFeatures(nr_class=nr_class, hist_size=hist_size, + nr_dim=nr_dim), + zero_init(Affine(nr_class, nr_class+hist_size*nr_dim, + drop_factor=0.0))) + upper.is_noop = False + else: + upper = chain() + upper.is_noop = True + elif USE_HISTORY: upper = chain( - clone(Maxout(hidden_width), depth-1), + HistoryFeatures(nr_class=nr_class, hist_size=8, nr_dim=8), + Maxout(hidden_width, hidden_width+8*8), zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) ) upper.is_noop = False + else: + upper = chain( + Maxout(hidden_width, hidden_width), + zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) + ) + upper.is_noop = False + # TODO: This is an unfortunate hack atm! # Used to set input dimensions in network. lower.begin_training(lower.ops.allocate((500, token_vector_width))) - upper.begin_training(upper.ops.allocate((500, hidden_width))) cfg = { 'nr_class': nr_class, 'depth': depth, @@ -428,12 +446,18 @@ cdef class Parser: self._parse_step(next_step[i], feat_weights, nr_class, nr_feat, nr_piece) else: + hists = [] for i in range(nr_step): st = next_step[i] st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat) self.moves.set_valid(&c_is_valid[i*nr_class], st) + hists.append([st.get_hist(j+1) for j in range(8)]) + hists = numpy.asarray(hists) vectors = state2vec(token_ids[:next_step.size()]) - scores = vec2scores(vectors) + if USE_HISTORY: + scores = vec2scores((vectors, hists)) + else: + scores = vec2scores(vectors) c_scores = scores.data for i in range(nr_step): st = next_step[i] @@ -441,6 +465,7 @@ cdef class Parser: &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class) action = self.moves.c[guess] action.do(st, action.label) + st.push_hist(guess) this_step, next_step = next_step, this_step next_step.clear() for st in this_step: @@ -551,7 +576,11 @@ cdef class Parser: if drop != 0: mask = vec2scores.ops.get_dropout_mask(vector.shape, drop) vector *= mask - scores, bp_scores = vec2scores.begin_update(vector, drop=drop) + hists = numpy.asarray([st.history for st in states], dtype='i') + if USE_HISTORY: + scores, bp_scores = vec2scores.begin_update((vector, hists), drop=drop) + else: + scores, bp_scores = vec2scores.begin_update(vector, drop=drop) d_scores = self.get_batch_loss(states, golds, scores) d_scores /= len(docs) @@ -570,7 +599,8 @@ cdef class Parser: else: backprops.append((token_ids, d_vector, bp_vector)) self.transition_batch(states, scores) - todo = [st for st in todo if not st[0].is_final()] + todo = [(st, gold) for (st, gold) in todo + if not st.is_final()] if losses is not None: losses[self.name] += (d_scores**2).sum() n_steps += 1 @@ -706,12 +736,15 @@ cdef class Parser: cdef StateClass state cdef int[500] is_valid # TODO: Unhack cdef float* c_scores = &scores[0, 0] + hists = [] for state in states: self.moves.set_valid(is_valid, state.c) guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1]) action = self.moves.c[guess] action.do(state.c, action.label) c_scores += scores.shape[1] + hists.append(guess) + return hists def get_batch_loss(self, states, golds, float[:, ::1] scores): cdef StateClass state From b770f4e1082b1da83597bb723a0e1986befdd069 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 3 Oct 2017 13:26:55 +0200 Subject: [PATCH 04/27] Fix embed class in history features --- spacy/_ml.py | 52 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index 38f220cc1..3b6e4da10 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -21,7 +21,6 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed from thinc.api import FeatureExtracter, with_getitem from thinc.neural.pooling import Pooling, max_pool, mean_pool, sum_pool from thinc.neural._classes.attention import ParametricAttention -from thinc.neural._classes.embed import Embed from thinc.linear.linear import LinearModel from thinc.api import uniqued, wrap, flatten_add_lengths, noop @@ -212,23 +211,60 @@ class PrecomputableMaxouts(Model): return dXf return Yfp, backward +# Thinc's Embed class is a bit broken atm, so drop this here. +from thinc import describe +from thinc.neural._classes.embed import _uniform_init + +@describe.attributes( + nV=describe.Dimension("Number of vectors"), + nO=describe.Dimension("Size of output"), + vectors=describe.Weights("Embedding table", + lambda obj: (obj.nV, obj.nO), + _uniform_init(-0.1, 0.1) + ), + d_vectors=describe.Gradient("vectors") +) +class Embed(Model): + name = 'embed' + + def __init__(self, nO, nV=None, **kwargs): + Model.__init__(self, **kwargs) + self.column = kwargs.get('column', 0) + self.nO = nO + self.nV = nV + + def predict(self, ids): + if ids.ndim == 2: + ids = ids[:, self.column] + return self._embed(ids) + + def begin_update(self, ids, drop=0.): + if ids.ndim == 2: + ids = ids[:, self.column] + vectors = self.vectors[ids] + def backprop_embed(d_vectors, sgd=None): + n_vectors = d_vectors.shape[0] + self.ops.scatter_add(self.d_vectors, ids, d_vectors) + if sgd is not None: + sgd(self._mem.weights, self._mem.gradient, key=self.id) + return None + return vectors, backprop_embed + def HistoryFeatures(nr_class, hist_size=8, nr_dim=8): '''Wrap a model, adding features representing action history.''' - embed = Embed(nr_dim, nr_dim, nr_class) + embed_tables = [Embed(nr_dim, nr_class, column=i) for i in range(hist_size)] + embed = concatenate(*embed_tables) ops = embed.ops def add_history_fwd(vectors_hists, drop=0.): vectors, hist_ids = vectors_hists - flat_hists, bp_hists = embed.begin_update(hist_ids.flatten(), drop=drop) - hists = flat_hists.reshape((hist_ids.shape[0], - hist_ids.shape[1] * flat_hists.shape[1])) - outputs = ops.xp.hstack((vectors, hists)) + hist_feats, bp_hists = embed.begin_update(hist_ids) + outputs = ops.xp.hstack((vectors, hist_feats)) def add_history_bwd(d_outputs, sgd=None): d_vectors = d_outputs[:, :vectors.shape[1]] d_hists = d_outputs[:, vectors.shape[1]:] - bp_hists(d_hists.reshape((d_hists.shape[0]*hist_size, - int(d_hists.shape[1]/hist_size))), sgd=sgd) + bp_hists(d_hists, sgd=sgd) return embed.ops.xp.ascontiguousarray(d_vectors) return outputs, add_history_bwd return wrap(add_history_fwd, embed) From 278a4c17c642b366b71dccc7cec202dc22cfcb93 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 3 Oct 2017 13:27:10 +0200 Subject: [PATCH 05/27] Fix history features --- spacy/syntax/nn_parser.pyx | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 2277e568e..4a874e834 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -70,6 +70,8 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG from . import _beam_utils USE_HISTORY = True +HIST_SIZE = 2 +HIST_DIMS = 16 def get_templates(*args, **kwargs): return [] @@ -262,13 +264,11 @@ cdef class Parser: with Model.use_device('cpu'): if depth == 0: - hist_size = 8 - nr_dim = 8 if USE_HISTORY: upper = chain( - HistoryFeatures(nr_class=nr_class, hist_size=hist_size, - nr_dim=nr_dim), - zero_init(Affine(nr_class, nr_class+hist_size*nr_dim, + HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE, + nr_dim=HIST_DIMS), + zero_init(Affine(nr_class, nr_class+HIST_SIZE*HIST_DIMS, drop_factor=0.0))) upper.is_noop = False else: @@ -736,15 +736,13 @@ cdef class Parser: cdef StateClass state cdef int[500] is_valid # TODO: Unhack cdef float* c_scores = &scores[0, 0] - hists = [] for state in states: self.moves.set_valid(is_valid, state.c) guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1]) action = self.moves.c[guess] action.do(state.c, action.label) c_scores += scores.shape[1] - hists.append(guess) - return hists + state.c.push_hist(guess) def get_batch_loss(self, states, golds, float[:, ::1] scores): cdef StateClass state From dc3c79194763d28e5d9e34918c22a05585d151cc Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 3 Oct 2017 13:41:23 +0200 Subject: [PATCH 06/27] Fix history size option --- spacy/syntax/nn_parser.pyx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 4a874e834..87099aa4f 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -70,8 +70,8 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG from . import _beam_utils USE_HISTORY = True -HIST_SIZE = 2 -HIST_DIMS = 16 +HIST_SIZE = 8 # Max 8 +HIST_DIMS = 8 def get_templates(*args, **kwargs): return [] @@ -276,8 +276,8 @@ cdef class Parser: upper.is_noop = True elif USE_HISTORY: upper = chain( - HistoryFeatures(nr_class=nr_class, hist_size=8, nr_dim=8), - Maxout(hidden_width, hidden_width+8*8), + HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE, nr_dim=HIST_DIMS), + Maxout(hidden_width, hidden_width+HIST_SIZE*HIST_DIMS), zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) ) upper.is_noop = False From 92066b04d6401207d8b0fad4d121eca06a7210c5 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 4 Oct 2017 19:55:34 -0500 Subject: [PATCH 07/27] Fix Embed and HistoryFeatures --- spacy/_ml.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index 6df10b6b2..6ebccd69a 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -231,6 +231,8 @@ class Embed(Model): def __init__(self, nO, nV=None, **kwargs): Model.__init__(self, **kwargs) + if 'name' in kwargs: + self.name = kwargs['name'] self.column = kwargs.get('column', 0) self.nO = nO self.nV = nV @@ -238,12 +240,12 @@ class Embed(Model): def predict(self, ids): if ids.ndim == 2: ids = ids[:, self.column] - return self._embed(ids) + return self.ops.xp.ascontiguousarray(self.vectors[ids]) def begin_update(self, ids, drop=0.): if ids.ndim == 2: ids = ids[:, self.column] - vectors = self.vectors[ids] + vectors = self.ops.xp.ascontiguousarray(self.vectors[ids]) def backprop_embed(d_vectors, sgd=None): n_vectors = d_vectors.shape[0] self.ops.scatter_add(self.d_vectors, ids, d_vectors) @@ -255,7 +257,8 @@ class Embed(Model): def HistoryFeatures(nr_class, hist_size=8, nr_dim=8): '''Wrap a model, adding features representing action history.''' - embed_tables = [Embed(nr_dim, nr_class, column=i) for i in range(hist_size)] + embed_tables = [Embed(nr_dim, nr_class, column=i, name='embed%d') + for i in range(hist_size)] embed = concatenate(*embed_tables) ops = embed.ops def add_history_fwd(vectors_hists, drop=0.): From 943af4423a9b3f0eced972c420ae023d8e3a1dd4 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 4 Oct 2017 20:06:05 -0500 Subject: [PATCH 08/27] Make depth setting in parser work again --- spacy/syntax/nn_parser.pyx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 016807e87..422b0fdc7 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -277,6 +277,7 @@ cdef class Parser: upper = chain( HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE, nr_dim=HIST_DIMS), Maxout(hidden_width, hidden_width+HIST_SIZE*HIST_DIMS), + clone(Maxout(hidden_width, hidden_width), depth-2), zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) ) upper.is_noop = False @@ -286,7 +287,7 @@ cdef class Parser: zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) ) upper.is_noop = False - + # TODO: This is an unfortunate hack atm! # Used to set input dimensions in network. lower.begin_training(lower.ops.allocate((500, token_vector_width))) From dcdfa071aaf983adae5b3fb39336a2b1102970ab Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 4 Oct 2017 20:06:52 -0500 Subject: [PATCH 09/27] Disable LayerNorm hack --- spacy/_ml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index 1f78de9a9..6223715b5 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -32,7 +32,7 @@ import io # TODO: Unset this once we don't want to support models previous models. import thinc.neural._classes.layernorm -thinc.neural._classes.layernorm.set_compat_six_eight(True) +thinc.neural._classes.layernorm.set_compat_six_eight(False) VECTORS_KEY = 'spacy_pretrained_vectors' From e25ffcb11f349c1f411d6d51280146eb5f72126a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 5 Oct 2017 19:38:13 -0500 Subject: [PATCH 10/27] Move history size under feature flags --- spacy/syntax/nn_parser.pyx | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 422b0fdc7..b57e8b466 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -68,14 +68,10 @@ from ..gold cimport GoldParse from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG from . import _beam_utils -USE_HISTORY = True -HIST_SIZE = 8 # Max 8 -HIST_DIMS = 8 def get_templates(*args, **kwargs): return [] -USE_FTRL = True DEBUG = False def set_debug(val): global DEBUG @@ -248,6 +244,8 @@ cdef class Parser: hidden_width = util.env_opt('hidden_width', hidden_width) parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2) embed_size = util.env_opt('embed_size', 7000) + hist_size = util.env_opt('history_feats', cfg.get('history_feats', 0)) + hist_width = util.env_opt('history_width', cfg.get('history_width', 0)) tok2vec = Tok2Vec(token_vector_width, embed_size, pretrained_dims=cfg.get('pretrained_dims', 0)) tok2vec = chain(tok2vec, flatten) @@ -263,20 +261,21 @@ cdef class Parser: with Model.use_device('cpu'): if depth == 0: - if USE_HISTORY: + if hist_size: upper = chain( - HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE, - nr_dim=HIST_DIMS), - zero_init(Affine(nr_class, nr_class+HIST_SIZE*HIST_DIMS, + HistoryFeatures(nr_class=nr_class, hist_size=hist_size, + nr_dim=hist_width), + zero_init(Affine(nr_class, nr_class+hist_size*hist_size, drop_factor=0.0))) upper.is_noop = False else: upper = chain() upper.is_noop = True - elif USE_HISTORY: + elif hist_size: upper = chain( - HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE, nr_dim=HIST_DIMS), - Maxout(hidden_width, hidden_width+HIST_SIZE*HIST_DIMS), + HistoryFeatures(nr_class=nr_class, hist_size=hist_size, + nr_dim=hist_width), + Maxout(hidden_width, hidden_width+hist_size*hist_width), clone(Maxout(hidden_width, hidden_width), depth-2), zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) ) @@ -296,7 +295,9 @@ cdef class Parser: 'depth': depth, 'token_vector_width': token_vector_width, 'hidden_width': hidden_width, - 'maxout_pieces': parser_maxout_pieces + 'maxout_pieces': parser_maxout_pieces, + 'hist_size': hist_size, + 'hist_width': hist_width } return (tok2vec, lower, upper), cfg @@ -369,7 +370,7 @@ cdef class Parser: _cleanup(beam) return output - def pipe(self, docs, int batch_size=1000, int n_threads=2, + def pipe(self, docs, int batch_size=256, int n_threads=2, beam_width=None, beam_density=None): """ Process a stream of documents. @@ -454,7 +455,7 @@ cdef class Parser: hists.append([st.get_hist(j+1) for j in range(8)]) hists = numpy.asarray(hists) vectors = state2vec(token_ids[:next_step.size()]) - if USE_HISTORY: + if self.cfg.get('hist_size'): scores = vec2scores((vectors, hists)) else: scores = vec2scores(vectors) @@ -577,7 +578,7 @@ cdef class Parser: mask = vec2scores.ops.get_dropout_mask(vector.shape, drop) vector *= mask hists = numpy.asarray([st.history for st in states], dtype='i') - if USE_HISTORY: + if self.cfg.get('hist_size', 0): scores, bp_scores = vec2scores.begin_update((vector, hists), drop=drop) else: scores, bp_scores = vec2scores.begin_update(vector, drop=drop) From fc06b0a33357352c99c5b1e41789c15920daac73 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 5 Oct 2017 21:52:28 -0500 Subject: [PATCH 11/27] Fix training when hist_size==0 --- spacy/_ml.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spacy/_ml.py b/spacy/_ml.py index 6223715b5..d6e745f22 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -257,6 +257,8 @@ class Embed(Model): def HistoryFeatures(nr_class, hist_size=8, nr_dim=8): '''Wrap a model, adding features representing action history.''' + if hist_size == 0: + return layerize(noop()) embed_tables = [Embed(nr_dim, nr_class, column=i, name='embed%d') for i in range(hist_size)] embed = concatenate(*embed_tables) From ca1276477289b570249967d595f232225df90184 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 5 Oct 2017 21:53:29 -0500 Subject: [PATCH 12/27] Enable history features for beam parser --- spacy/syntax/_beam_utils.pyx | 11 ++++++++--- spacy/syntax/nn_parser.pyx | 11 +++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx index a26900f6b..da4efefbc 100644 --- a/spacy/syntax/_beam_utils.pyx +++ b/spacy/syntax/_beam_utils.pyx @@ -21,6 +21,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) moves = _moves dest.clone(src) moves[clas].do(dest.c, moves[clas].label) + dest.c.push_hist(clas) cdef int _check_final_state(void* _state, void* extra_args) except -1: @@ -148,8 +149,8 @@ def get_token_ids(states, int n_tokens): nr_update = 0 def update_beam(TransitionSystem moves, int nr_feature, int max_steps, states, golds, - state2vec, vec2scores, - int width, float density, + state2vec, vec2scores, + int width, float density, int hist_feats, losses=None, drop=0.): global nr_update cdef MaxViolation violn @@ -180,7 +181,11 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps, # Now that we have our flat list of states, feed them through the model token_ids = get_token_ids(states, nr_feature) vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop) - scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) + if hist_feats: + hists = numpy.asarray([st.history[:hist_feats] for st in states], dtype='i') + scores, bp_scores = vec2scores.begin_update((vectors, hists), drop=drop) + else: + scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) # Store the callbacks for the backward pass backprops.append((token_ids, bp_vectors, bp_scores)) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index b57e8b466..9a071ae14 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -505,7 +505,12 @@ cdef class Parser: states.append(stcls) token_ids = self.get_token_ids(states) vectors = state2vec(token_ids) - scores = vec2scores(vectors) + if self.cfg.get('hist_size', 0): + hists = numpy.asarray([st.history[:self.cfg['hist_size']] + for st in states], dtype='i') + scores = vec2scores(vectors, drop=drop) + else: + scores = vec2scores(vectors, drop=drop) j = 0 c_scores = scores.data for i in range(beam.size): @@ -537,6 +542,7 @@ cdef class Parser: guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece) action = self.moves.c[guess] action.do(state, action.label) + state.push_hist(guess) free(is_valid) free(scores) @@ -634,7 +640,7 @@ cdef class Parser: states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500, states, golds, state2vec, vec2scores, - width, density, + width, density, self.cfg.get('hist_size', 0), drop=drop, losses=losses) backprop_lower = [] cdef float batch_size = len(docs) @@ -967,6 +973,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) moves = _moves dest.clone(src) moves[clas].do(dest.c, moves[clas].label) + dest.c.push_hist(clas) cdef int _check_final_state(void* _state, void* extra_args) except -1: From 363aa47b40b281d40ee9bfc187a8ba9b964ac913 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 5 Oct 2017 21:53:49 -0500 Subject: [PATCH 13/27] Clean up dead parsing code --- spacy/syntax/nn_parser.pyx | 2 -- spacy/syntax/stateclass.pyx | 24 ------------------------ 2 files changed, 26 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 9a071ae14..e2c2b41c7 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -529,8 +529,6 @@ cdef class Parser: const float* feat_weights, int nr_class, int nr_feat, int nr_piece) nogil: '''This only works with no hidden layers -- fast but inaccurate''' - #for i in cython.parallel.prange(next_step.size(), num_threads=4, nogil=True): - # self._parse_step(next_step[i], feat_weights, nr_class, nr_feat) token_ids = calloc(nr_feat, sizeof(int)) scores = calloc(nr_class * nr_piece, sizeof(float)) is_valid = calloc(nr_class, sizeof(int)) diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index 9c179820c..ddd1f558c 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -62,27 +62,3 @@ cdef class StateClass: n0 = words[self.B(0)] n1 = words[self.B(1)] return ' '.join((third, second, top, '|', n0, n1)) - - @classmethod - def nr_context_tokens(cls): - return 13 - - def set_context_tokens(self, int[::1] output): - output[0] = self.B(0) - output[1] = self.B(1) - output[2] = self.S(0) - output[3] = self.S(1) - output[4] = self.S(2) - output[5] = self.L(self.S(0), 1) - output[6] = self.L(self.S(0), 2) - output[6] = self.R(self.S(0), 1) - output[7] = self.L(self.B(0), 1) - output[8] = self.R(self.S(0), 2) - output[9] = self.L(self.S(1), 1) - output[10] = self.L(self.S(1), 2) - output[11] = self.R(self.S(1), 1) - output[12] = self.R(self.S(1), 2) - - for i in range(13): - if output[i] != -1: - output[i] += self.c.offset From b0618def8d5e03d24e732432d858a34f1301b314 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 5 Oct 2017 21:54:12 -0500 Subject: [PATCH 14/27] Add support for 2-token state option --- spacy/syntax/_state.pxd | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index f4fa49286..50146401e 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -93,6 +93,9 @@ cdef cppclass StateC: free(this.shifted - PADDING) void set_context_tokens(int* ids, int n) nogil: + if n == 2: + ids[0] = this.B(0) + ids[1] = this.S(0) if n == 8: ids[0] = this.B(0) ids[1] = this.B(1) From 3db0a32fd651d7d8bd99f9a73eeae1124875a85e Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 5 Oct 2017 22:21:30 -0500 Subject: [PATCH 15/27] Fix dropout for history features --- spacy/_ml.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index d6e745f22..7761e6d11 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -217,6 +217,7 @@ class PrecomputableMaxouts(Model): from thinc import describe from thinc.neural._classes.embed import _uniform_init + @describe.attributes( nV=describe.Dimension("Number of vectors"), nO=describe.Dimension("Size of output"), @@ -240,12 +241,12 @@ class Embed(Model): def predict(self, ids): if ids.ndim == 2: ids = ids[:, self.column] - return self.ops.xp.ascontiguousarray(self.vectors[ids]) + return self.ops.xp.ascontiguousarray(self.vectors[ids], dtype='f') def begin_update(self, ids, drop=0.): if ids.ndim == 2: ids = ids[:, self.column] - vectors = self.ops.xp.ascontiguousarray(self.vectors[ids]) + vectors = self.ops.xp.ascontiguousarray(self.vectors[ids], dtype='f') def backprop_embed(d_vectors, sgd=None): n_vectors = d_vectors.shape[0] self.ops.scatter_add(self.d_vectors, ids, d_vectors) @@ -267,8 +268,13 @@ def HistoryFeatures(nr_class, hist_size=8, nr_dim=8): vectors, hist_ids = vectors_hists hist_feats, bp_hists = embed.begin_update(hist_ids) outputs = ops.xp.hstack((vectors, hist_feats)) + mask = ops.get_dropout_mask(outputs.shape, drop) + if mask is not None: + outputs *= mask def add_history_bwd(d_outputs, sgd=None): + if mask is not None: + d_outputs *= mask d_vectors = d_outputs[:, :vectors.shape[1]] d_hists = d_outputs[:, vectors.shape[1]:] bp_hists(d_hists, sgd=sgd) From 555d8c8bffc8a3b31c0f3396b02fcf45cba4bd96 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 5 Oct 2017 22:21:50 -0500 Subject: [PATCH 16/27] Fix beam history features --- spacy/syntax/nn_parser.pyx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index e2c2b41c7..2b244bb70 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -508,9 +508,9 @@ cdef class Parser: if self.cfg.get('hist_size', 0): hists = numpy.asarray([st.history[:self.cfg['hist_size']] for st in states], dtype='i') - scores = vec2scores(vectors, drop=drop) + scores = vec2scores((vectors, hists)) else: - scores = vec2scores(vectors, drop=drop) + scores = vec2scores(vectors) j = 0 c_scores = scores.data for i in range(beam.size): @@ -723,7 +723,7 @@ cdef class Parser: lower, stream, drop=0.0) return (tokvecs, bp_tokvecs), state2vec, upper - nr_feature = 8 + nr_feature = 2 def get_token_ids(self, states): cdef StateClass state From 21d11936fea53d9b67a2ae306a4825cdd15fcc6c Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 6 Oct 2017 06:08:50 -0500 Subject: [PATCH 17/27] Fix significant train/test skew error in history feats --- spacy/syntax/nn_parser.pyx | 1 + 1 file changed, 1 insertion(+) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 2b244bb70..3bca59b60 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -684,6 +684,7 @@ cdef class Parser: while state.B(0) < start and not state.is_final(): action = self.moves.c[oracle_actions.pop(0)] action.do(state.c, action.label) + state.c.push_hist(action.clas) n_moves += 1 has_gold = self.moves.has_gold(gold, start=start, end=start+max_length) From fbba7c517ece539f9b1c24df4f545b56189def72 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 6 Oct 2017 06:09:18 -0500 Subject: [PATCH 18/27] Pass dropout through to embed tables --- spacy/_ml.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index 7761e6d11..f79c5668a 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -266,15 +266,10 @@ def HistoryFeatures(nr_class, hist_size=8, nr_dim=8): ops = embed.ops def add_history_fwd(vectors_hists, drop=0.): vectors, hist_ids = vectors_hists - hist_feats, bp_hists = embed.begin_update(hist_ids) + hist_feats, bp_hists = embed.begin_update(hist_ids, drop=drop) outputs = ops.xp.hstack((vectors, hist_feats)) - mask = ops.get_dropout_mask(outputs.shape, drop) - if mask is not None: - outputs *= mask def add_history_bwd(d_outputs, sgd=None): - if mask is not None: - d_outputs *= mask d_vectors = d_outputs[:, :vectors.shape[1]] d_hists = d_outputs[:, vectors.shape[1]:] bp_hists(d_hists, sgd=sgd) From 5c750a9c2f5c69e16c7c6c5e90d10870d0210e29 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 6 Oct 2017 06:10:13 -0500 Subject: [PATCH 19/27] Reserve 0 for 'missing' in history features --- spacy/_ml.py | 2 ++ spacy/syntax/_state.pxd | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index f79c5668a..898d6ab49 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -231,6 +231,8 @@ class Embed(Model): name = 'embed' def __init__(self, nO, nV=None, **kwargs): + if nV is not None: + nV += 1 Model.__init__(self, **kwargs) if 'name' in kwargs: self.name = kwargs['name'] diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 50146401e..1864b22b3 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -297,7 +297,7 @@ cdef cppclass StateC: + hash64(&this._hist, sizeof(RingBufferC), 1) void push_hist(int act) nogil: - ring_push(&this._hist, act) + ring_push(&this._hist, act+1) int get_hist(int i) nogil: return ring_get(&this._hist, i) From c66399d8ae1e65580491fa7b0873fea1f8aeca0c Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 6 Oct 2017 06:20:05 -0500 Subject: [PATCH 20/27] Fix depth definition with history features --- spacy/syntax/nn_parser.pyx | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 3bca59b60..f9c8c0c14 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -246,6 +246,9 @@ cdef class Parser: embed_size = util.env_opt('embed_size', 7000) hist_size = util.env_opt('history_feats', cfg.get('history_feats', 0)) hist_width = util.env_opt('history_width', cfg.get('history_width', 0)) + if hist_size >= 1 and depth == 0: + raise ValueError("Inconsistent hyper-params: " + "history_feats >= 1 but parser_hidden_depth==0") tok2vec = Tok2Vec(token_vector_width, embed_size, pretrained_dims=cfg.get('pretrained_dims', 0)) tok2vec = chain(tok2vec, flatten) @@ -261,16 +264,15 @@ cdef class Parser: with Model.use_device('cpu'): if depth == 0: - if hist_size: - upper = chain( - HistoryFeatures(nr_class=nr_class, hist_size=hist_size, - nr_dim=hist_width), - zero_init(Affine(nr_class, nr_class+hist_size*hist_size, - drop_factor=0.0))) - upper.is_noop = False - else: - upper = chain() - upper.is_noop = True + upper = chain() + upper.is_noop = True + elif hist_size and depth == 1: + upper = chain( + HistoryFeatures(nr_class=nr_class, hist_size=hist_size, + nr_dim=hist_width), + zero_init(Affine(nr_class, hidden_width+hist_size*hist_width, + drop_factor=0.0))) + upper.is_noop = False elif hist_size: upper = chain( HistoryFeatures(nr_class=nr_class, hist_size=hist_size, @@ -282,7 +284,7 @@ cdef class Parser: upper.is_noop = False else: upper = chain( - Maxout(hidden_width, hidden_width), + clone(Maxout(hidden_width, hidden_width), depth-1), zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) ) upper.is_noop = False From 16ba6aa8a66b69eeeef482dc3247bc46e938aec7 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 6 Oct 2017 13:17:31 -0500 Subject: [PATCH 21/27] Fix parser config serialization --- spacy/syntax/nn_parser.pyx | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index f9c8c0c14..9ae53b103 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -238,14 +238,15 @@ cdef class Parser: Base class of the DependencyParser and EntityRecognizer. """ @classmethod - def Model(cls, nr_class, token_vector_width=128, hidden_width=200, depth=1, **cfg): - depth = util.env_opt('parser_hidden_depth', depth) - token_vector_width = util.env_opt('token_vector_width', token_vector_width) - hidden_width = util.env_opt('hidden_width', hidden_width) - parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2) - embed_size = util.env_opt('embed_size', 7000) - hist_size = util.env_opt('history_feats', cfg.get('history_feats', 0)) - hist_width = util.env_opt('history_width', cfg.get('history_width', 0)) + def Model(cls, nr_class, **cfg): + depth = util.env_opt('parser_hidden_depth', cfg.get('parser_hidden_depth', 1)) + token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128)) + hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 200)) + parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('parser_maxout_pieces', 3)) + embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000)) + hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0)) + hist_width = util.env_opt('history_width', cfg.get('hist_width', 0)) + print("Create parser model", locals()) if hist_size >= 1 and depth == 0: raise ValueError("Inconsistent hyper-params: " "history_feats >= 1 but parser_hidden_depth==0") @@ -277,14 +278,14 @@ cdef class Parser: upper = chain( HistoryFeatures(nr_class=nr_class, hist_size=hist_size, nr_dim=hist_width), - Maxout(hidden_width, hidden_width+hist_size*hist_width), - clone(Maxout(hidden_width, hidden_width), depth-2), + LayerNorm(Maxout(hidden_width, hidden_width+hist_size*hist_width)), + clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-2), zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) ) upper.is_noop = False else: upper = chain( - clone(Maxout(hidden_width, hidden_width), depth-1), + clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-1), zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) ) upper.is_noop = False From f4c9a98166feacc788f2d93e834ae2cf3e0332d2 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 6 Oct 2017 13:17:47 -0500 Subject: [PATCH 22/27] Fix spacy evaluate command on non-GPU --- spacy/cli/evaluate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py index 42e077dc2..29e30b7d2 100644 --- a/spacy/cli/evaluate.py +++ b/spacy/cli/evaluate.py @@ -42,7 +42,8 @@ def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False, Evaluate a model. To render a sample of parses in a HTML file, set an output directory as the displacy_path argument. """ - util.use_gpu(gpu_id) + if gpu_id >= 0: + util.use_gpu(gpu_id) util.set_env_log(False) data_path = util.ensure_path(data_path) displacy_path = util.ensure_path(displacy_path) From 8e731009fea6afa67862c6293248fac244836d70 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 6 Oct 2017 13:50:52 -0500 Subject: [PATCH 23/27] Fix parser config serialization --- spacy/syntax/nn_parser.pyx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 9ae53b103..bb1ec1b4a 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -239,10 +239,10 @@ cdef class Parser: """ @classmethod def Model(cls, nr_class, **cfg): - depth = util.env_opt('parser_hidden_depth', cfg.get('parser_hidden_depth', 1)) + depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 1)) token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128)) hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 200)) - parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('parser_maxout_pieces', 3)) + parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 3)) embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000)) hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0)) hist_width = util.env_opt('history_width', cfg.get('hist_width', 0)) @@ -295,7 +295,7 @@ cdef class Parser: lower.begin_training(lower.ops.allocate((500, token_vector_width))) cfg = { 'nr_class': nr_class, - 'depth': depth, + 'hidden_depth': depth, 'token_vector_width': token_vector_width, 'hidden_width': hidden_width, 'maxout_pieces': parser_maxout_pieces, @@ -727,7 +727,7 @@ cdef class Parser: lower, stream, drop=0.0) return (tokvecs, bp_tokvecs), state2vec, upper - nr_feature = 2 + nr_feature = 8 def get_token_ids(self, states): cdef StateClass state From 8be46d766e1b5c97abe44793ca0e278ac0b3657c Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 6 Oct 2017 16:19:02 -0500 Subject: [PATCH 24/27] Remove print statement --- spacy/syntax/nn_parser.pyx | 1 - 1 file changed, 1 deletion(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index bb1ec1b4a..b5f218d75 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -246,7 +246,6 @@ cdef class Parser: embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000)) hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0)) hist_width = util.env_opt('history_width', cfg.get('hist_width', 0)) - print("Create parser model", locals()) if hist_size >= 1 and depth == 0: raise ValueError("Inconsistent hyper-params: " "history_feats >= 1 but parser_hidden_depth==0") From e22067e3b538e70e55cd20a6ffd4d0a2e64a2f26 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 7 Oct 2017 07:10:10 -0500 Subject: [PATCH 25/27] Document new hyper-parameters --- website/api/_top-level/_cli.jade | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/website/api/_top-level/_cli.jade b/website/api/_top-level/_cli.jade index f59d5afdd..5c91b48e8 100644 --- a/website/api/_top-level/_cli.jade +++ b/website/api/_top-level/_cli.jade @@ -314,6 +314,16 @@ p +cell Size of the parser's and NER's hidden layers. +cell #[code 128] + +row + +cell #[code history_feats] + +cell Number of previous action ID features for parser and NER + +cell #[code 128] + + +row + +cell #[code history_width] + +cell Number of embedding dimensions for each action ID + +cell #[code 128] + +row +cell #[code learn_rate] +cell Learning rate. From 3d22ccf4954fabc5bbbdf766b6a3ad3a8609692d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 7 Oct 2017 07:16:41 -0500 Subject: [PATCH 26/27] Update default hyper-parameters --- spacy/syntax/nn_parser.pyx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index b5f218d75..fdcf1d2d1 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -239,13 +239,13 @@ cdef class Parser: """ @classmethod def Model(cls, nr_class, **cfg): - depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 1)) + depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 2)) token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128)) - hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 200)) - parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 3)) + hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 128)) + parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 1)) embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000)) - hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0)) - hist_width = util.env_opt('history_width', cfg.get('hist_width', 0)) + hist_size = util.env_opt('history_feats', cfg.get('hist_size', 4)) + hist_width = util.env_opt('history_width', cfg.get('hist_width', 16)) if hist_size >= 1 and depth == 0: raise ValueError("Inconsistent hyper-params: " "history_feats >= 1 but parser_hidden_depth==0") From d70cf1915889bbc4463d353427fb1655a2e922a1 Mon Sep 17 00:00:00 2001 From: ines Date: Sat, 7 Oct 2017 15:06:38 +0200 Subject: [PATCH 27/27] Fix formatting --- website/api/_top-level/_cli.jade | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/website/api/_top-level/_cli.jade b/website/api/_top-level/_cli.jade index 5c91b48e8..3a4b4702a 100644 --- a/website/api/_top-level/_cli.jade +++ b/website/api/_top-level/_cli.jade @@ -316,12 +316,12 @@ p +row +cell #[code history_feats] - +cell Number of previous action ID features for parser and NER + +cell Number of previous action ID features for parser and NER. +cell #[code 128] +row +cell #[code history_width] - +cell Number of embedding dimensions for each action ID + +cell Number of embedding dimensions for each action ID. +cell #[code 128] +row