From d44b1b337a14c4b78bbf48958d45561b88bbaa1d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 13 Mar 2017 11:24:02 +0100 Subject: [PATCH 1/7] Try using LinearModel in tagger. --- spacy/tagger.pxd | 9 ++- spacy/tagger.pyx | 148 +++++++++++++++++++++++++++++++++-------------- 2 files changed, 109 insertions(+), 48 deletions(-) diff --git a/spacy/tagger.pxd b/spacy/tagger.pxd index 6d2cef1f4..ed4e3d9c4 100644 --- a/spacy/tagger.pxd +++ b/spacy/tagger.pxd @@ -1,17 +1,20 @@ from thinc.linear.avgtron cimport AveragedPerceptron from thinc.extra.eg cimport Example from thinc.structs cimport ExampleC +from thinc.linear.features cimport ConjunctionExtracter from .structs cimport TokenC from .vocab cimport Vocab -cdef class TaggerModel(AveragedPerceptron): - cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except * - +cdef class TaggerModel: + cdef ConjunctionExtracter extracter + cdef object model + cdef class Tagger: cdef readonly Vocab vocab cdef readonly TaggerModel model cdef public dict freqs cdef public object cfg + cdef public object optimizer diff --git a/spacy/tagger.pyx b/spacy/tagger.pyx index 6f034f3de..1c11387b3 100644 --- a/spacy/tagger.pyx +++ b/spacy/tagger.pyx @@ -1,14 +1,25 @@ +# cython: infer_types=True +# cython: profile=True import json import pathlib from collections import defaultdict -from libc.string cimport memset +from libc.string cimport memset, memcpy +from libcpp.vector cimport vector +from libc.stdint cimport uint64_t, int32_t, int64_t +cimport numpy as np +import numpy as np +np.import_array() from cymem.cymem cimport Pool from thinc.typedefs cimport atom_t, weight_t from thinc.extra.eg cimport Example from thinc.structs cimport ExampleC from thinc.linear.avgtron cimport AveragedPerceptron -from thinc.linalg cimport VecVec +from thinc.linalg cimport Vec, VecVec +from thinc.linear.linear import LinearModel +from thinc.structs cimport FeatureC +from thinc.neural.optimizers import Adam +from thinc.neural.ops import NumpyOps from .typedefs cimport attr_t from .tokens.doc cimport Doc @@ -69,24 +80,69 @@ cpdef enum: N_CONTEXT_FIELDS -cdef class TaggerModel(AveragedPerceptron): - def update(self, Example eg): - self.time += 1 - guess = eg.guess - best = VecVec.arg_max_if_zero(eg.c.scores, eg.c.costs, eg.c.nr_class) - if guess != best: - for feat in eg.c.features[:eg.c.nr_feat]: - self.update_weight_ftrl(feat.key, best, -feat.value) - self.update_weight_ftrl(feat.key, guess, feat.value) +cdef class TaggerModel: + def __init__(self, int nr_tag, templates): + self.extracter = ConjunctionExtracter(templates) + self.model = LinearModel(nr_tag) - cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except *: - _fill_from_token(&eg.atoms[P2_orth], &tokens[i-2]) - _fill_from_token(&eg.atoms[P1_orth], &tokens[i-1]) - _fill_from_token(&eg.atoms[W_orth], &tokens[i]) - _fill_from_token(&eg.atoms[N1_orth], &tokens[i+1]) - _fill_from_token(&eg.atoms[N2_orth], &tokens[i+2]) + def begin_update(self, atom_t[:, ::1] contexts, drop=0.): + cdef vector[uint64_t]* keys = new vector[uint64_t]() + cdef vector[float]* values = new vector[float]() + cdef vector[int64_t]* lengths = new vector[int64_t]() + features = new vector[FeatureC](self.extracter.nr_templ) + features.resize(self.extracter.nr_templ) + cdef FeatureC feat + cdef int i, j + for i in range(contexts.shape[0]): + nr_feat = self.extracter.set_features(features.data(), &contexts[i, 0]) + for j in range(nr_feat): + keys.push_back(features.at(j).key) + 
values.push_back(features.at(j).value) + lengths.push_back(nr_feat) + cdef np.ndarray[uint64_t, ndim=1] py_keys + cdef np.ndarray[float, ndim=1] py_values + cdef np.ndarray[long, ndim=1] py_lengths + py_keys = vector_uint64_2numpy(keys) + py_values = vector_float_2numpy(values) + py_lengths = vector_long_2numpy(lengths) + instance = (py_keys, py_values, py_lengths) + del keys + del values + del lengths + del features + return self.model.begin_update(instance, drop=drop) - eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms) + def end_training(self, *args, **kwargs): + pass + + def dump(self, *args, **kwargs): + pass + + +cdef np.ndarray[uint64_t, ndim=1] vector_uint64_2numpy(vector[uint64_t]* vec): + cdef np.ndarray[uint64_t, ndim=1, mode="c"] arr = np.zeros(vec.size(), dtype='uint64') + memcpy(arr.data, vec.data(), sizeof(uint64_t) * vec.size()) + return arr + + +cdef np.ndarray[long, ndim=1] vector_long_2numpy(vector[int64_t]* vec): + cdef np.ndarray[long, ndim=1, mode="c"] arr = np.zeros(vec.size(), dtype='int64') + memcpy(arr.data, vec.data(), sizeof(int64_t) * vec.size()) + return arr + + +cdef np.ndarray[float, ndim=1] vector_float_2numpy(vector[float]* vec): + cdef np.ndarray[float, ndim=1, mode="c"] arr = np.zeros(vec.size(), dtype='float32') + memcpy(arr.data, vec.data(), sizeof(float) * vec.size()) + return arr + + +cdef void fill_context(atom_t* context, const TokenC* tokens, int i) nogil: + _fill_from_token(&context[P2_orth], &tokens[i-2]) + _fill_from_token(&context[P1_orth], &tokens[i-1]) + _fill_from_token(&context[W_orth], &tokens[i]) + _fill_from_token(&context[N1_orth], &tokens[i+1]) + _fill_from_token(&context[N2_orth], &tokens[i+2]) cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil: @@ -157,17 +213,17 @@ cdef class Tagger: The newly constructed object. """ if model is None: - model = TaggerModel(cfg.get('features', self.feature_templates), - L1=0.0) + model = TaggerModel(vocab.morphology.n_tags, + cfg.get('features', self.feature_templates)) self.vocab = vocab self.model = model - self.model.l1_penalty = 0.0 # TODO: Move this to tag map self.freqs = {TAG: defaultdict(int)} for tag in self.tag_names: self.freqs[TAG][self.vocab.strings[tag]] = 1 self.freqs[TAG][0] = 1 self.cfg = cfg + self.optimizer = Adam(NumpyOps(), 0.001) @property def tag_names(self): @@ -194,20 +250,20 @@ cdef class Tagger: if tokens.length == 0: return 0 - cdef Pool mem = Pool() + cdef atom_t[1][N_CONTEXT_FIELDS] c_context + memset(c_context, 0, sizeof(c_context)) + cdef atom_t[:, ::1] context = c_context + cdef float[:, ::1] scores - cdef int i, tag - cdef Example eg = Example(nr_atom=N_CONTEXT_FIELDS, - nr_class=self.vocab.morphology.n_tags, - nr_feat=self.model.nr_feat) + cdef int nr_class = self.vocab.morphology.n_tags for i in range(tokens.length): if tokens.c[i].pos == 0: - self.model.set_featuresC(&eg.c, tokens.c, i) - self.model.set_scoresC(eg.c.scores, - eg.c.features, eg.c.nr_feat) - guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) + fill_context(&context[0, 0], tokens.c, i) + scores, _ = self.model.begin_update(context) + + guess = Vec.arg_max(&scores[0, 0], nr_class) self.vocab.morphology.assign_tag_id(&tokens.c[i], guess) - eg.fill_scores(0, eg.c.nr_class) + memset(&scores[0, 0], 0, sizeof(float) * scores.size) tokens.is_tagged = True tokens._py_tokens = [None] * tokens.length @@ -239,6 +295,7 @@ cdef class Tagger: Returns (int): Number of tags correct. 
""" + cdef int nr_class = self.vocab.morphology.n_tags gold_tag_strs = gold.tags assert len(tokens) == len(gold_tag_strs) for tag in gold_tag_strs: @@ -248,24 +305,25 @@ cdef class Tagger: raise ValueError(msg % tag) golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs] cdef int correct = 0 - cdef Pool mem = Pool() - cdef Example eg = Example( - nr_atom=N_CONTEXT_FIELDS, - nr_class=self.vocab.morphology.n_tags, - nr_feat=self.model.nr_feat) + + cdef atom_t[:, ::1] context = np.zeros((1, N_CONTEXT_FIELDS), dtype='uint64') + cdef float[:, ::1] scores + for i in range(tokens.length): - self.model.set_featuresC(&eg.c, tokens.c, i) - eg.costs = [ 1 if golds[i] not in (c, -1) else 0 for c in xrange(eg.nr_class) ] - self.model.set_scoresC(eg.c.scores, - eg.c.features, eg.c.nr_feat) - self.model.update(eg) + fill_context(&context[0, 0], tokens.c, i) + scores, finish_update = self.model.begin_update(context) + guess = Vec.arg_max(&scores[0, 0], nr_class) + self.vocab.morphology.assign_tag_id(&tokens.c[i], guess) - self.vocab.morphology.assign_tag_id(&tokens.c[i], eg.guess) + if golds[i] != -1: + scores[0, golds[i]] -= 1 + finish_update(scores, lambda *args, **kwargs: None) - correct += eg.cost == 0 + if (golds[i] in (guess, -1)): + correct += 1 self.freqs[TAG][tokens.c[i].tag] += 1 - eg.fill_scores(0, eg.c.nr_class) - eg.fill_costs(0, eg.c.nr_class) + self.optimizer(self.model.model.weights, self.model.model.d_weights, + key=self.model.model.id) tokens.is_tagged = True tokens._py_tokens = [None] * tokens.length return correct From 2ac166eacd80c2b0054cb1a6c8adbd0b09176e69 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 13 Mar 2017 11:24:36 +0100 Subject: [PATCH 2/7] Add cython compilation flags to gold.pyx --- spacy/gold.pyx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 1e9a0194f..806ab9857 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -1,3 +1,5 @@ +# cython: profile=True +# cython: infer_types=True from __future__ import unicode_literals, print_function import numpy From 755d7d486c298962e718ed1cd738d68390431cab Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 14 Mar 2017 21:28:43 +0100 Subject: [PATCH 3/7] WIP on hash kernel --- setup.py | 1 + spacy/_ml.pxd | 31 +++++++ spacy/_ml.pyx | 151 ++++++++++++++++++++++++++++++++ spacy/about.py | 6 +- spacy/syntax/_state.pxd | 12 +-- spacy/syntax/arc_eager.pyx | 8 +- spacy/syntax/parser.pxd | 11 ++- spacy/syntax/parser.pyx | 170 +++++++++++++++++++++++-------------- spacy/tagger.pxd | 15 ++-- spacy/tagger.pyx | 153 +++++++++++++-------------------- spacy/train.py | 4 +- 11 files changed, 383 insertions(+), 179 deletions(-) create mode 100644 spacy/_ml.pxd create mode 100644 spacy/_ml.pyx diff --git a/setup.py b/setup.py index 26f395ea5..373d5af9d 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ MOD_NAMES = [ 'spacy.lexeme', 'spacy.vocab', 'spacy.attrs', + 'spacy._ml', 'spacy.morphology', 'spacy.tagger', 'spacy.pipeline', diff --git a/spacy/_ml.pxd b/spacy/_ml.pxd new file mode 100644 index 000000000..4f2f42427 --- /dev/null +++ b/spacy/_ml.pxd @@ -0,0 +1,31 @@ +from thinc.linear.features cimport ConjunctionExtracter +from thinc.typedefs cimport atom_t, weight_t +from thinc.structs cimport FeatureC +from libc.stdint cimport uint32_t +cimport numpy as np +from cymem.cymem cimport Pool + + +cdef class LinearModel: + cdef ConjunctionExtracter extracter + cdef readonly int nr_class + cdef readonly uint32_t nr_weight + cdef public weight_t learn_rate + cdef 
Pool mem + cdef weight_t* W + cdef weight_t* d_W + + cdef void hinge_lossC(self, weight_t* d_scores, + const weight_t* scores, const weight_t* costs) nogil + + cdef void log_lossC(self, weight_t* d_scores, + const weight_t* scores, const weight_t* costs) nogil + + cdef void regression_lossC(self, weight_t* d_scores, + const weight_t* scores, const weight_t* costs) nogil + + cdef void set_scoresC(self, weight_t* scores, + const FeatureC* features, int nr_feat) nogil + + cdef void set_gradientC(self, const weight_t* d_scores, const FeatureC* + features, int nr_feat) nogil diff --git a/spacy/_ml.pyx b/spacy/_ml.pyx new file mode 100644 index 000000000..c3413f561 --- /dev/null +++ b/spacy/_ml.pyx @@ -0,0 +1,151 @@ +# cython: infer_types=True +# cython: profile=True +# cython: cdivision=True + +from libcpp.vector cimport vector +from libc.stdint cimport uint64_t, uint32_t, int32_t +from libc.string cimport memcpy, memset +cimport libcpp.algorithm +from libc.math cimport exp + +from cymem.cymem cimport Pool +from thinc.linalg cimport Vec, VecVec +from murmurhash.mrmr cimport hash64 +cimport numpy as np +import numpy +np.import_array() + + +cdef class LinearModel: + def __init__(self, int nr_class, templates, weight_t learn_rate=0.001, + size=2**18): + self.extracter = ConjunctionExtracter(templates) + self.nr_weight = size + self.nr_class = nr_class + self.learn_rate = learn_rate + self.mem = Pool() + self.W = self.mem.alloc(self.nr_weight * self.nr_class, + sizeof(weight_t)) + self.d_W = self.mem.alloc(self.nr_weight * self.nr_class, + sizeof(weight_t)) + + cdef void hinge_lossC(self, weight_t* d_scores, + const weight_t* scores, const weight_t* costs) nogil: + guess = 0 + best = -1 + for i in range(1, self.nr_class): + if scores[i] > scores[guess]: + guess = i + if costs[i] == 0 and (best == -1 or scores[i] > scores[best]): + best = i + if best != -1 and scores[guess] >= scores[best]: + d_scores[guess] = 1. + d_scores[best] = -1. 
+ + cdef void log_lossC(self, weight_t* d_scores, + const weight_t* scores, const weight_t* costs) nogil: + for i in range(self.nr_class): + if costs[i] <= 0: + break + else: + return + cdef double Z = 1e-10 + cdef double gZ = 1e-10 + cdef double max_ = scores[0] + cdef double g_max = -9000 + for i in range(self.nr_class): + max_ = max(max_, scores[i]) + if costs[i] <= 0: + g_max = max(g_max, scores[i]) + for i in range(self.nr_class): + Z += exp(scores[i]-max_) + if costs[i] <= 0: + gZ += exp(scores[i]-g_max) + for i in range(self.nr_class): + score = exp(scores[i]-max_) + if costs[i] >= 1: + d_scores[i] = score / Z + else: + g_score = exp(scores[i]-g_max) + d_scores[i] = (score / Z) - (g_score / gZ) + + cdef void regression_lossC(self, weight_t* d_scores, + const weight_t* scores, const weight_t* costs) nogil: + best = -1 + for i in range(self.nr_class): + if costs[i] <= 0: + if best == -1: + best = i + elif scores[i] > scores[best]: + best = i + if best == -1: + return + for i in range(self.nr_class): + if scores[i] < scores[best]: + d_scores[i] = 0 + elif costs[i] <= 0 and scores[i] == best: + continue + else: + d_scores[i] = scores[i] - -costs[i] + + cdef void set_scoresC(self, weight_t* scores, + const FeatureC* features, int nr_feat) nogil: + cdef uint64_t nr_weight = self.nr_weight + cdef int nr_class = self.nr_class + cdef vector[uint64_t] indices + # Collect all feature indices + cdef uint32_t[2] hashed + cdef FeatureC feat + cdef uint64_t hash2 + for feat in features[:nr_feat]: + if feat.value == 0: + continue + memcpy(hashed, &feat.key, sizeof(hashed)) + indices.push_back(hashed[0] % nr_weight) + indices.push_back(hashed[1] % nr_weight) + + # Sort them, to improve memory access pattern + libcpp.algorithm.sort(indices.begin(), indices.end()) + for idx in indices: + W = &self.W[idx * nr_class] + for clas in range(nr_class): + scores[clas] += W[clas] + + cdef void set_gradientC(self, const weight_t* d_scores, const FeatureC* + features, int nr_feat) nogil: + cdef uint64_t nr_weight = self.nr_weight + cdef int nr_class = self.nr_class + cdef vector[uint64_t] indices + # Collect all feature indices + cdef uint32_t[2] hashed + cdef uint64_t hash2 + for feat in features[:nr_feat]: + if feat.value == 0: + continue + memcpy(hashed, &feat.key, sizeof(hashed)) + indices.push_back(hashed[0] % nr_weight) + indices.push_back(hashed[1] % nr_weight) + + # Sort them, to improve memory access pattern + libcpp.algorithm.sort(indices.begin(), indices.end()) + for idx in indices: + W = &self.W[idx * nr_class] + for clas in range(nr_class): + if d_scores[clas] < 0: + W[clas] -= self.learn_rate * max(-10., d_scores[clas]) + else: + W[clas] -= self.learn_rate * min(10., d_scores[clas]) + + @property + def nr_active_feat(self): + return self.nr_weight + + @property + def nr_feat(self): + return self.extracter.nr_templ + + def end_training(self, *args, **kwargs): + pass + + def dump(self, *args, **kwargs): + pass diff --git a/spacy/about.py b/spacy/about.py index d51dea286..57e845a5c 100644 --- a/spacy/about.py +++ b/spacy/about.py @@ -4,13 +4,13 @@ # https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py __title__ = 'spacy' -__version__ = '1.6.0' +__version__ = '1.7.0' __summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython' __uri__ = 'https://spacy.io' __author__ = 'Matthew Honnibal' __email__ = 'matt@explosion.ai' __license__ = 'MIT' __models__ = { - 'en': 'en>=1.1.0,<1.2.0', - 'de': 'de>=1.0.0,<1.1.0', + 'en': 'en>=1.2.0,<1.3.0', + 'de': 
'de>=1.2.0,<1.3.0', } diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index c764e877d..383e91faa 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -304,11 +304,13 @@ cdef cppclass StateC: this._break = this._b_i void clone(const StateC* src) nogil: - memcpy(this._sent, src._sent, this.length * sizeof(TokenC)) - memcpy(this._stack, src._stack, this.length * sizeof(int)) - memcpy(this._buffer, src._buffer, this.length * sizeof(int)) - memcpy(this._ents, src._ents, this.length * sizeof(Entity)) - memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0])) + # This is still quadratic, but make it a it faster. + # Not carefully reviewed for accuracy yet. + memcpy(this._sent, src._sent, this.B(1) * sizeof(TokenC)) + memcpy(this._stack, src._stack, this._s_i * sizeof(int)) + memcpy(this._buffer, src._buffer, this._b_i * sizeof(int)) + memcpy(this._ents, src._ents, this._e_i * sizeof(Entity)) + memcpy(this.shifted, src.shifted, this.B(2) * sizeof(this.shifted[0])) this.length = src.length this._b_i = src._b_i this._s_i = src._s_i diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 7049b8595..a0e2bf4d0 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -70,7 +70,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: cdef weight_t cost = 0 cdef int i, B_i - for i in range(stcls.buffer_length()): + for i in range(min(30, stcls.buffer_length())): B_i = stcls.B(i) cost += gold.heads[B_i] == target cost += gold.heads[target] == B_i @@ -268,10 +268,12 @@ cdef class Break: cdef int i, j, S_i, B_i for i in range(s.stack_depth()): S_i = s.S(i) - for j in range(s.buffer_length()): + for j in range(min(30, s.buffer_length())): B_i = s.B(j) cost += gold.heads[S_i] == B_i cost += gold.heads[B_i] == S_i + if cost != 0: + break # Check for sentence boundary --- if it's here, we can't have any deps # between stack and buffer, so rest of action is irrelevant. 
s0_root = _get_root(s.S(0), gold) @@ -462,7 +464,7 @@ cdef class ArcEager(TransitionSystem): cdef int* labels = gold.c.labels cdef int* heads = gold.c.heads - n_gold = 0 + cdef int n_gold = 0 for i in range(self.n_moves): if self.c[i].is_valid(stcls.c, self.c[i].label): is_valid[i] = True diff --git a/spacy/syntax/parser.pxd b/spacy/syntax/parser.pxd index aaed10303..020e1e793 100644 --- a/spacy/syntax/parser.pxd +++ b/spacy/syntax/parser.pxd @@ -1,5 +1,6 @@ from thinc.linear.avgtron cimport AveragedPerceptron -from thinc.typedefs cimport atom_t +from thinc.linear.features cimport ConjunctionExtracter +from thinc.typedefs cimport atom_t, weight_t from thinc.structs cimport FeatureC from .stateclass cimport StateClass @@ -8,17 +9,19 @@ from ..vocab cimport Vocab from ..tokens.doc cimport Doc from ..structs cimport TokenC from ._state cimport StateC +from .._ml cimport LinearModel -cdef class ParserModel(AveragedPerceptron): +cdef class ParserModel(LinearModel): cdef int set_featuresC(self, atom_t* context, FeatureC* features, const StateC* state) nogil - - + + cdef class Parser: cdef readonly Vocab vocab cdef readonly ParserModel model cdef readonly TransitionSystem moves cdef readonly object cfg + cdef public object optimizer cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 804542cc8..dc157d13d 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -1,4 +1,6 @@ # cython: infer_types=True +# cython: cdivision=True +# cython: profile=True """ MALT-style dependency parser """ @@ -20,15 +22,22 @@ import shutil import json import sys from .nonproj import PseudoProjectivity +import numpy +import random +cimport numpy as np +np.import_array() from cymem.cymem cimport Pool, Address -from murmurhash.mrmr cimport hash64 +from murmurhash.mrmr cimport hash64, hash32 from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t from thinc.linear.avgtron cimport AveragedPerceptron from thinc.linalg cimport VecVec from thinc.structs cimport SparseArrayC from preshed.maps cimport MapStruct from preshed.maps cimport map_get +from thinc.neural.ops import NumpyOps +from thinc.neural.optimizers import Adam +from thinc.neural.optimizers import SGD from thinc.structs cimport FeatureC from thinc.structs cimport ExampleC @@ -51,6 +60,7 @@ from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport fill_context from .stateclass cimport StateClass from ._state cimport StateC +from .._ml cimport LinearModel DEBUG = False @@ -72,57 +82,65 @@ def get_templates(name): pf.tree_shape + pf.trigrams) -cdef class ParserModel(AveragedPerceptron): +#cdef class ParserModel(AveragedPerceptron): +# cdef int set_featuresC(self, atom_t* context, FeatureC* features, +# const StateC* state) nogil: +# fill_context(context, state) +# nr_feat = self.extracter.set_features(features, context) +# return nr_feat +# +# def update(self, Example eg, itn=0): +# '''Does regression on negative cost. 
Sort of cute?''' +# self.time += 1 +# best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class) +# guess = eg.guess +# cdef weight_t loss = 0.0 +# if guess == best: +# return loss +# for clas in [guess, best]: +# loss += (-eg.c.costs[clas] - eg.c.scores[clas]) ** 2 +# d_loss = eg.c.scores[clas] - -eg.c.costs[clas] +# for feat in eg.c.features[:eg.c.nr_feat]: +# self.update_weight_ftrl(feat.key, clas, feat.value * d_loss) +# return loss +# +# def update_from_histories(self, TransitionSystem moves, Doc doc, histories, weight_t min_grad=0.0): +# cdef Pool mem = Pool() +# features = mem.alloc(self.nr_feat, sizeof(FeatureC)) +# +# cdef StateClass stcls +# +# cdef class_t clas +# self.time += 1 +# cdef atom_t[CONTEXT_SIZE] atoms +# histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad and hist] +# if not histories: +# return None +# gradient = [Counter() for _ in range(max([max(h)+1 for _, h in histories]))] +# for d_loss, history in histories: +# stcls = StateClass.init(doc.c, doc.length) +# moves.initialize_state(stcls.c) +# for clas in history: +# nr_feat = self.set_featuresC(atoms, features, stcls.c) +# clas_grad = gradient[clas] +# for feat in features[:nr_feat]: +# clas_grad[feat.key] += d_loss * feat.value +# moves.c[clas].do(stcls.c, moves.c[clas].label) +# cdef feat_t key +# cdef weight_t d_feat +# for clas, clas_grad in enumerate(gradient): +# for key, d_feat in clas_grad.items(): +# if d_feat != 0: +# self.update_weight_ftrl(key, clas, d_feat) +# + +cdef class ParserModel(LinearModel): cdef int set_featuresC(self, atom_t* context, FeatureC* features, const StateC* state) nogil: fill_context(context, state) nr_feat = self.extracter.set_features(features, context) return nr_feat - def update(self, Example eg, itn=0): - '''Does regression on negative cost. 
Sort of cute?''' - self.time += 1 - best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class) - guess = eg.guess - cdef weight_t loss = 0.0 - if guess == best: - return loss - for clas in [guess, best]: - loss += (-eg.c.costs[clas] - eg.c.scores[clas]) ** 2 - d_loss = eg.c.scores[clas] - -eg.c.costs[clas] - for feat in eg.c.features[:eg.c.nr_feat]: - self.update_weight_ftrl(feat.key, clas, feat.value * d_loss) - return loss - - def update_from_histories(self, TransitionSystem moves, Doc doc, histories, weight_t min_grad=0.0): - cdef Pool mem = Pool() - features = mem.alloc(self.nr_feat, sizeof(FeatureC)) - - cdef StateClass stcls - - cdef class_t clas - self.time += 1 - cdef atom_t[CONTEXT_SIZE] atoms - histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad and hist] - if not histories: - return None - gradient = [Counter() for _ in range(max([max(h)+1 for _, h in histories]))] - for d_loss, history in histories: - stcls = StateClass.init(doc.c, doc.length) - moves.initialize_state(stcls.c) - for clas in history: - nr_feat = self.set_featuresC(atoms, features, stcls.c) - clas_grad = gradient[clas] - for feat in features[:nr_feat]: - clas_grad[feat.key] += d_loss * feat.value - moves.c[clas].do(stcls.c, moves.c[clas].label) - cdef feat_t key - cdef weight_t d_feat - for clas, clas_grad in enumerate(gradient): - for key, d_feat in clas_grad.items(): - if d_feat != 0: - self.update_weight_ftrl(key, clas, d_feat) - cdef class Parser: """Base class of the DependencyParser and EntityRecognizer.""" @@ -174,9 +192,14 @@ cdef class Parser: cfg['features'] = get_templates(cfg['features']) elif 'features' not in cfg: cfg['features'] = self.feature_templates - self.model = ParserModel(cfg['features']) - self.model.l1_penalty = cfg.get('L1', 1e-8) - self.model.learn_rate = cfg.get('learn_rate', 0.001) + self.model = ParserModel(self.moves.n_moves, cfg['features'], + size=2**18, + learn_rate=cfg.get('learn_rate', 0.001)) + #self.model.l1_penalty = cfg.get('L1', 1e-8) + #self.model.learn_rate = cfg.get('learn_rate', 0.001) + + self.optimizer = SGD(NumpyOps(), cfg.get('learn_rate', 0.001), + momentum=0.9) self.cfg = cfg @@ -300,27 +323,48 @@ cdef class Parser: self.moves.preprocess_gold(gold) cdef StateClass stcls = StateClass.init(tokens.c, tokens.length) self.moves.initialize_state(stcls.c) + + cdef int nr_class = self.model.nr_class cdef Pool mem = Pool() - cdef Example eg = Example( - nr_class=self.moves.n_moves, - nr_atom=CONTEXT_SIZE, - nr_feat=self.model.nr_feat) + d_scores = mem.alloc(nr_class, sizeof(weight_t)) + scores = mem.alloc(nr_class, sizeof(weight_t)) + costs = mem.alloc(nr_class, sizeof(weight_t)) + features = mem.alloc(self.model.nr_feat, sizeof(FeatureC)) + is_valid = mem.alloc(self.moves.n_moves, sizeof(int)) + cdef atom_t[CONTEXT_SIZE] context + cdef weight_t loss = 0 cdef Transition action + words = [w.text for w in tokens] + while not stcls.is_final(): - eg.c.nr_feat = self.model.set_featuresC(eg.c.atoms, eg.c.features, - stcls.c) - self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold) - self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat) - guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) - self.model.update(eg) + + nr_feat = self.model.set_featuresC(context, features, stcls.c) + self.moves.set_costs(is_valid, costs, stcls, gold) + self.model.set_scoresC(scores, features, nr_feat) + + guess = VecVec.arg_max_if_true(scores, is_valid, nr_class) + best = arg_max_if_gold(scores, costs, nr_class) + + 
self.model.regression_lossC(d_scores, scores, costs) + self.model.set_gradientC(d_scores, features, nr_feat) action = self.moves.c[guess] action.do(stcls.c, action.label) - loss += eg.costs[guess] - eg.fill_scores(0, eg.c.nr_class) - eg.fill_costs(0, eg.c.nr_class) - eg.fill_is_valid(1, eg.c.nr_class) + #print(scores[guess], scores[best], d_scores[guess], costs[guess], + # self.moves.move_name(action.move, action.label), stcls.print_state(words)) + + loss += scores[guess] + memset(context, 0, sizeof(context)) + memset(features, 0, sizeof(features[0]) * nr_feat) + memset(scores, 0, sizeof(scores[0]) * nr_class) + memset(d_scores, 0, sizeof(d_scores[0]) * nr_class) + memset(costs, 0, sizeof(costs[0]) * nr_class) + for i in range(nr_class): + is_valid[i] = 1 + #if itn % 100 == 0: + # self.optimizer(self.model.model[0].ravel(), + # self.model.model[1].ravel(), key=1) return loss def step_through(self, Doc doc): diff --git a/spacy/tagger.pxd b/spacy/tagger.pxd index ed4e3d9c4..deab79fab 100644 --- a/spacy/tagger.pxd +++ b/spacy/tagger.pxd @@ -1,15 +1,14 @@ -from thinc.linear.avgtron cimport AveragedPerceptron -from thinc.extra.eg cimport Example -from thinc.structs cimport ExampleC -from thinc.linear.features cimport ConjunctionExtracter - from .structs cimport TokenC from .vocab cimport Vocab +from ._ml cimport LinearModel +from thinc.structs cimport FeatureC +from thinc.typedefs cimport atom_t -cdef class TaggerModel: - cdef ConjunctionExtracter extracter - cdef object model +cdef class TaggerModel(LinearModel): + cdef int set_featuresC(self, FeatureC* features, atom_t* context, + const TokenC* tokens, int i) nogil + cdef class Tagger: diff --git a/spacy/tagger.pyx b/spacy/tagger.pyx index 1c11387b3..76807b328 100644 --- a/spacy/tagger.pyx +++ b/spacy/tagger.pyx @@ -16,9 +16,8 @@ from thinc.extra.eg cimport Example from thinc.structs cimport ExampleC from thinc.linear.avgtron cimport AveragedPerceptron from thinc.linalg cimport Vec, VecVec -from thinc.linear.linear import LinearModel from thinc.structs cimport FeatureC -from thinc.neural.optimizers import Adam +from thinc.neural.optimizers import Adam, SGD from thinc.neural.ops import NumpyOps from .typedefs cimport attr_t @@ -80,69 +79,16 @@ cpdef enum: N_CONTEXT_FIELDS -cdef class TaggerModel: - def __init__(self, int nr_tag, templates): - self.extracter = ConjunctionExtracter(templates) - self.model = LinearModel(nr_tag) - - def begin_update(self, atom_t[:, ::1] contexts, drop=0.): - cdef vector[uint64_t]* keys = new vector[uint64_t]() - cdef vector[float]* values = new vector[float]() - cdef vector[int64_t]* lengths = new vector[int64_t]() - features = new vector[FeatureC](self.extracter.nr_templ) - features.resize(self.extracter.nr_templ) - cdef FeatureC feat - cdef int i, j - for i in range(contexts.shape[0]): - nr_feat = self.extracter.set_features(features.data(), &contexts[i, 0]) - for j in range(nr_feat): - keys.push_back(features.at(j).key) - values.push_back(features.at(j).value) - lengths.push_back(nr_feat) - cdef np.ndarray[uint64_t, ndim=1] py_keys - cdef np.ndarray[float, ndim=1] py_values - cdef np.ndarray[long, ndim=1] py_lengths - py_keys = vector_uint64_2numpy(keys) - py_values = vector_float_2numpy(values) - py_lengths = vector_long_2numpy(lengths) - instance = (py_keys, py_values, py_lengths) - del keys - del values - del lengths - del features - return self.model.begin_update(instance, drop=drop) - - def end_training(self, *args, **kwargs): - pass - - def dump(self, *args, **kwargs): - pass - - -cdef 
np.ndarray[uint64_t, ndim=1] vector_uint64_2numpy(vector[uint64_t]* vec): - cdef np.ndarray[uint64_t, ndim=1, mode="c"] arr = np.zeros(vec.size(), dtype='uint64') - memcpy(arr.data, vec.data(), sizeof(uint64_t) * vec.size()) - return arr - - -cdef np.ndarray[long, ndim=1] vector_long_2numpy(vector[int64_t]* vec): - cdef np.ndarray[long, ndim=1, mode="c"] arr = np.zeros(vec.size(), dtype='int64') - memcpy(arr.data, vec.data(), sizeof(int64_t) * vec.size()) - return arr - - -cdef np.ndarray[float, ndim=1] vector_float_2numpy(vector[float]* vec): - cdef np.ndarray[float, ndim=1, mode="c"] arr = np.zeros(vec.size(), dtype='float32') - memcpy(arr.data, vec.data(), sizeof(float) * vec.size()) - return arr - - -cdef void fill_context(atom_t* context, const TokenC* tokens, int i) nogil: - _fill_from_token(&context[P2_orth], &tokens[i-2]) - _fill_from_token(&context[P1_orth], &tokens[i-1]) - _fill_from_token(&context[W_orth], &tokens[i]) - _fill_from_token(&context[N1_orth], &tokens[i+1]) - _fill_from_token(&context[N2_orth], &tokens[i+2]) +cdef class TaggerModel(LinearModel): + cdef int set_featuresC(self, FeatureC* features, atom_t* context, + const TokenC* tokens, int i) nogil: + _fill_from_token(&context[P2_orth], &tokens[i-2]) + _fill_from_token(&context[P1_orth], &tokens[i-1]) + _fill_from_token(&context[W_orth], &tokens[i]) + _fill_from_token(&context[N1_orth], &tokens[i+1]) + _fill_from_token(&context[N2_orth], &tokens[i+2]) + nr_feat = self.extracter.set_features(features, context) + return nr_feat cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil: @@ -213,8 +159,10 @@ cdef class Tagger: The newly constructed object. """ if model is None: + print("Create tagger") model = TaggerModel(vocab.morphology.n_tags, - cfg.get('features', self.feature_templates)) + cfg.get('features', self.feature_templates), + learn_rate=0.01, size=2**18) self.vocab = vocab self.model = model # TODO: Move this to tag map @@ -223,7 +171,7 @@ cdef class Tagger: self.freqs[TAG][self.vocab.strings[tag]] = 1 self.freqs[TAG][0] = 1 self.cfg = cfg - self.optimizer = Adam(NumpyOps(), 0.001) + self.optimizer = SGD(NumpyOps(), 0.001, momentum=0.9) @property def tag_names(self): @@ -250,20 +198,22 @@ cdef class Tagger: if tokens.length == 0: return 0 - cdef atom_t[1][N_CONTEXT_FIELDS] c_context - memset(c_context, 0, sizeof(c_context)) - cdef atom_t[:, ::1] context = c_context - cdef float[:, ::1] scores + cdef atom_t[N_CONTEXT_FIELDS] context cdef int nr_class = self.vocab.morphology.n_tags + cdef Pool mem = Pool() + scores = mem.alloc(nr_class, sizeof(weight_t)) + features = mem.alloc(self.model.nr_feat, sizeof(FeatureC)) for i in range(tokens.length): if tokens.c[i].pos == 0: - fill_context(&context[0, 0], tokens.c, i) - scores, _ = self.model.begin_update(context) - - guess = Vec.arg_max(&scores[0, 0], nr_class) + nr_feat = self.model.set_featuresC(features, context, tokens.c, i) + self.model.set_scoresC(scores, + features, nr_feat) + guess = Vec.arg_max(scores, nr_class) self.vocab.morphology.assign_tag_id(&tokens.c[i], guess) - memset(&scores[0, 0], 0, sizeof(float) * scores.size) + memset(scores, 0, sizeof(weight_t) * nr_class) + memset(features, 0, sizeof(FeatureC) * nr_feat) + memset(context, 0, sizeof(N_CONTEXT_FIELDS)) tokens.is_tagged = True tokens._py_tokens = [None] * tokens.length @@ -295,7 +245,6 @@ cdef class Tagger: Returns (int): Number of tags correct. 
""" - cdef int nr_class = self.vocab.morphology.n_tags gold_tag_strs = gold.tags assert len(tokens) == len(gold_tag_strs) for tag in gold_tag_strs: @@ -303,27 +252,47 @@ cdef class Tagger: msg = ("Unrecognized gold tag: %s. tag_map.json must contain all " "gold tags, to maintain coarse-grained mapping.") raise ValueError(msg % tag) - golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs] + cdef Pool mem = Pool() + golds = mem.alloc(sizeof(int), len(gold_tag_strs)) + for i, g in enumerate(gold_tag_strs): + golds[i] = self.tag_names.index(g) if g is not None else -1 + + cdef atom_t[N_CONTEXT_FIELDS] context + cdef int nr_class = self.model.nr_class + costs = mem.alloc(sizeof(weight_t), nr_class) + features = mem.alloc(sizeof(FeatureC), self.model.nr_feat) + scores = mem.alloc(sizeof(weight_t), nr_class) + d_scores = mem.alloc(sizeof(weight_t), nr_class) + cdef int correct = 0 - - cdef atom_t[:, ::1] context = np.zeros((1, N_CONTEXT_FIELDS), dtype='uint64') - cdef float[:, ::1] scores - for i in range(tokens.length): - fill_context(&context[0, 0], tokens.c, i) - scores, finish_update = self.model.begin_update(context) - guess = Vec.arg_max(&scores[0, 0], nr_class) - self.vocab.morphology.assign_tag_id(&tokens.c[i], guess) + nr_feat = self.model.set_featuresC(features, context, tokens.c, i) + self.model.set_scoresC(scores, + features, nr_feat) if golds[i] != -1: - scores[0, golds[i]] -= 1 - finish_update(scores, lambda *args, **kwargs: None) + for j in range(nr_class): + costs[j] = 1 + costs[golds[i]] = 0 + self.model.log_lossC(d_scores, scores, costs) + self.model.set_gradientC(d_scores, features, nr_feat) + + guess = Vec.arg_max(scores, nr_class) + #print(tokens[i].text, golds[i], guess, [features[i].key for i in range(nr_feat)]) + + self.vocab.morphology.assign_tag_id(&tokens.c[i], guess) - if (golds[i] in (guess, -1)): - correct += 1 self.freqs[TAG][tokens.c[i].tag] += 1 - self.optimizer(self.model.model.weights, self.model.model.d_weights, - key=self.model.model.id) + correct += costs[guess] == 0 + + memset(features, 0, sizeof(FeatureC) * nr_feat) + memset(costs, 0, sizeof(weight_t) * nr_class) + memset(scores, 0, sizeof(weight_t) * nr_class) + memset(d_scores, 0, sizeof(weight_t) * nr_class) + + #if itn % 10 == 0: + # self.optimizer(self.model.weights.ravel(), self.model.d_weights.ravel(), + # key=1) tokens.is_tagged = True tokens._py_tokens = [None] * tokens.length return correct diff --git a/spacy/train.py b/spacy/train.py index 175c99cf2..2f8748791 100644 --- a/spacy/train.py +++ b/spacy/train.py @@ -14,6 +14,7 @@ class Trainer(object): self.nlp = nlp self.gold_tuples = gold_tuples self.nr_epoch = 0 + self.nr_itn = 0 def epochs(self, nr_epoch, augment_data=None, gold_preproc=False): cached_golds = {} @@ -36,6 +37,7 @@ class Trainer(object): golds = self.make_golds(docs, paragraph_tuples) for doc, gold in zip(docs, golds): yield doc, gold + self.nr_itn += 1 indices = list(range(len(self.gold_tuples))) for itn in range(nr_epoch): @@ -46,7 +48,7 @@ class Trainer(object): def update(self, doc, gold): for process in self.nlp.pipeline: if hasattr(process, 'update'): - loss = process.update(doc, gold, itn=self.nr_epoch) + loss = process.update(doc, gold, itn=self.nr_itn) process(doc) return doc From abb209f631d4623196892db742f091148464318a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 24 Mar 2017 00:23:32 +0100 Subject: [PATCH 4/7] Track which indices are being used --- spacy/_ml.pxd | 4 +++- spacy/_ml.pyx | 20 +++++++++++++++++--- 2 files changed, 20 
insertions(+), 4 deletions(-) diff --git a/spacy/_ml.pxd b/spacy/_ml.pxd index 4f2f42427..fdf7b359c 100644 --- a/spacy/_ml.pxd +++ b/spacy/_ml.pxd @@ -1,9 +1,10 @@ from thinc.linear.features cimport ConjunctionExtracter from thinc.typedefs cimport atom_t, weight_t from thinc.structs cimport FeatureC -from libc.stdint cimport uint32_t +from libc.stdint cimport uint32_t, uint64_t cimport numpy as np from cymem.cymem cimport Pool +from libcpp.vector cimport vector cdef class LinearModel: @@ -14,6 +15,7 @@ cdef class LinearModel: cdef Pool mem cdef weight_t* W cdef weight_t* d_W + cdef vector[uint64_t]* _indices cdef void hinge_lossC(self, weight_t* d_scores, const weight_t* scores, const weight_t* costs) nogil diff --git a/spacy/_ml.pyx b/spacy/_ml.pyx index c3413f561..bec5b0cbc 100644 --- a/spacy/_ml.pyx +++ b/spacy/_ml.pyx @@ -15,6 +15,9 @@ cimport numpy as np import numpy np.import_array() +from thinc.neural.optimizers import Adam +from thinc.neural.ops import NumpyOps + cdef class LinearModel: def __init__(self, int nr_class, templates, weight_t learn_rate=0.001, @@ -28,6 +31,10 @@ cdef class LinearModel: sizeof(weight_t)) self.d_W = self.mem.alloc(self.nr_weight * self.nr_class, sizeof(weight_t)) + self._indices = new vector[uint64_t]() + + def __dealloc__(self): + del self._indices cdef void hinge_lossC(self, weight_t* d_scores, const weight_t* scores, const weight_t* costs) nogil: @@ -129,12 +136,19 @@ cdef class LinearModel: # Sort them, to improve memory access pattern libcpp.algorithm.sort(indices.begin(), indices.end()) for idx in indices: - W = &self.W[idx * nr_class] + d_W = &self.d_W[idx * nr_class] for clas in range(nr_class): if d_scores[clas] < 0: - W[clas] -= self.learn_rate * max(-10., d_scores[clas]) + d_W[clas] += max(-10., d_scores[clas]) else: - W[clas] -= self.learn_rate * min(10., d_scores[clas]) + d_W[clas] += min(10., d_scores[clas]) + + def finish_update(self, optimizer): + cdef np.npy_intp[1] shape + shape[0] = self.nr_weight * self.nr_class + W_arr = np.PyArray_SimpleNewFromData(1, shape, np.NPY_FLOAT, self.W) + dW_arr = np.PyArray_SimpleNewFromData(1, shape, np.NPY_FLOAT, self.d_W) + optimizer(W_arr, dW_arr, key=1) @property def nr_active_feat(self): From 6c31a7222f63753a27afaab1a2e655a66cf46c87 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 24 Mar 2017 00:23:59 +0100 Subject: [PATCH 5/7] Remove incorrect feature zeroing --- spacy/syntax/_parse_features.pyx | 1 - 1 file changed, 1 deletion(-) diff --git a/spacy/syntax/_parse_features.pyx b/spacy/syntax/_parse_features.pyx index bc54e0c9d..36a78c638 100644 --- a/spacy/syntax/_parse_features.pyx +++ b/spacy/syntax/_parse_features.pyx @@ -33,7 +33,6 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil: context[9] = 0 context[10] = 0 context[11] = 0 - context[12] = 0 else: context[0] = token.lex.orth context[1] = token.lemma From 1f292bfd17a59cce08ddaf5e3c6e866cf5d0ec65 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 30 Mar 2017 02:35:36 +0200 Subject: [PATCH 6/7] Play with hash kernel class --- spacy/_ml.pxd | 7 ++++-- spacy/_ml.pyx | 64 +++++++++++++++++++++++++++++++++------------------ 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/spacy/_ml.pxd b/spacy/_ml.pxd index fdf7b359c..8a5d35573 100644 --- a/spacy/_ml.pxd +++ b/spacy/_ml.pxd @@ -12,10 +12,13 @@ cdef class LinearModel: cdef readonly int nr_class cdef readonly uint32_t nr_weight cdef public weight_t learn_rate + cdef public weight_t momentum cdef Pool mem + cdef weight_t time cdef weight_t* W - 
cdef weight_t* d_W - cdef vector[uint64_t]* _indices + cdef weight_t* mom + cdef weight_t* averages + cdef weight_t* last_upd cdef void hinge_lossC(self, weight_t* d_scores, const weight_t* scores, const weight_t* costs) nogil diff --git a/spacy/_ml.pyx b/spacy/_ml.pyx index bec5b0cbc..582ea3624 100644 --- a/spacy/_ml.pyx +++ b/spacy/_ml.pyx @@ -20,21 +20,23 @@ from thinc.neural.ops import NumpyOps cdef class LinearModel: - def __init__(self, int nr_class, templates, weight_t learn_rate=0.001, - size=2**18): + def __init__(self, int nr_class, templates, + weight_t momentum=0.9, weight_t learn_rate=0.001, size=2**18): self.extracter = ConjunctionExtracter(templates) self.nr_weight = size self.nr_class = nr_class self.learn_rate = learn_rate + self.momentum = momentum self.mem = Pool() + self.time = 0 self.W = self.mem.alloc(self.nr_weight * self.nr_class, sizeof(weight_t)) - self.d_W = self.mem.alloc(self.nr_weight * self.nr_class, + self.mom = self.mem.alloc(self.nr_weight * self.nr_class, + sizeof(weight_t)) + self.averages = self.mem.alloc(self.nr_weight * self.nr_class, + sizeof(weight_t)) + self.last_upd = self.mem.alloc(self.nr_weight * self.nr_class, sizeof(weight_t)) - self._indices = new vector[uint64_t]() - - def __dealloc__(self): - del self._indices cdef void hinge_lossC(self, weight_t* d_scores, const weight_t* scores, const weight_t* costs) nogil: @@ -97,8 +99,8 @@ cdef class LinearModel: cdef void set_scoresC(self, weight_t* scores, const FeatureC* features, int nr_feat) nogil: - cdef uint64_t nr_weight = self.nr_weight cdef int nr_class = self.nr_class + cdef uint64_t nr_weight = self.nr_weight * nr_class - nr_class cdef vector[uint64_t] indices # Collect all feature indices cdef uint32_t[2] hashed @@ -114,16 +116,23 @@ cdef class LinearModel: # Sort them, to improve memory access pattern libcpp.algorithm.sort(indices.begin(), indices.end()) for idx in indices: - W = &self.W[idx * nr_class] + W = &self.W[idx] for clas in range(nr_class): scores[clas] += W[clas] cdef void set_gradientC(self, const weight_t* d_scores, const FeatureC* features, int nr_feat) nogil: - cdef uint64_t nr_weight = self.nr_weight + self.time += 1 cdef int nr_class = self.nr_class + cdef weight_t abs_grad = 0 + for i in range(nr_class): + abs_grad += d_scores[i] if d_scores[i] > 0 else -d_scores[i] + if abs_grad < 0.1: + return + cdef uint64_t nr_weight = self.nr_weight * nr_class - nr_class cdef vector[uint64_t] indices # Collect all feature indices + indices.reserve(nr_feat * 2) cdef uint32_t[2] hashed cdef uint64_t hash2 for feat in features[:nr_feat]: @@ -136,19 +145,24 @@ cdef class LinearModel: # Sort them, to improve memory access pattern libcpp.algorithm.sort(indices.begin(), indices.end()) for idx in indices: - d_W = &self.d_W[idx * nr_class] - for clas in range(nr_class): - if d_scores[clas] < 0: - d_W[clas] += max(-10., d_scores[clas]) - else: - d_W[clas] += min(10., d_scores[clas]) + #avg = &self.averages[idx] + #last_upd = &self.last_upd[idx] + W = &self.W[idx] + #mom = &self.mom[idx] + for i in range(nr_class): + if d_scores[i] == 0: + continue + d = d_scores[i] + W[i] -= self.learn_rate * d + #unchanged = self.time - last_upd[i] + #avg[i] += unchanged * W[i] + #mom[i] *= self.momentum ** unchanged + #mom[i] += self.learn_rate * d + #W[i] -= mom[i] + #last_upd[i] = self.time def finish_update(self, optimizer): - cdef np.npy_intp[1] shape - shape[0] = self.nr_weight * self.nr_class - W_arr = np.PyArray_SimpleNewFromData(1, shape, np.NPY_FLOAT, self.W) - dW_arr = 
np.PyArray_SimpleNewFromData(1, shape, np.NPY_FLOAT, self.d_W) - optimizer(W_arr, dW_arr, key=1) + pass @property def nr_active_feat(self): @@ -159,7 +173,13 @@ cdef class LinearModel: return self.extracter.nr_templ def end_training(self, *args, **kwargs): - pass + # Average weights + for i in range(self.nr_weight * self.nr_class): + unchanged = self.time - self.last_upd[i] + self.averages[i] += self.W[i] * unchanged + self.W[i], self.averages[i] = self.averages[i], self.W[i] + self.W[i] /= self.time + self.last_upd[i] = self.time def dump(self, *args, **kwargs): pass From 2a91d641e6ccafb888121c50b41c4fd90a00a816 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 30 Mar 2017 02:36:33 +0200 Subject: [PATCH 7/7] Add dropout to parser --- spacy/syntax/parser.pyx | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index dc157d13d..cffa96423 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -193,13 +193,11 @@ cdef class Parser: elif 'features' not in cfg: cfg['features'] = self.feature_templates self.model = ParserModel(self.moves.n_moves, cfg['features'], - size=2**18, + size=2**14, learn_rate=cfg.get('learn_rate', 0.001)) - #self.model.l1_penalty = cfg.get('L1', 1e-8) - #self.model.learn_rate = cfg.get('learn_rate', 0.001) + #self.model.l1_penalty = cfg.get('L1', 0.0) - self.optimizer = SGD(NumpyOps(), cfg.get('learn_rate', 0.001), - momentum=0.9) + self.optimizer = Adam(NumpyOps(), cfg.get('learn_rate', 0.001)) self.cfg = cfg @@ -337,9 +335,19 @@ cdef class Parser: cdef Transition action words = [w.text for w in tokens] + cdef int i + cdef double[::1] py_dropout + cdef double* dropout while not stcls.is_final(): nr_feat = self.model.set_featuresC(context, features, stcls.c) + py_dropout = numpy.random.uniform(0., 1., nr_feat) + dropout = &py_dropout[0] + for i in range(nr_feat): + if dropout[i] < 0.5: + features[i].value = 0 + else: + features[i].value *= 2 self.moves.set_costs(is_valid, costs, stcls, gold) self.model.set_scoresC(scores, features, nr_feat) @@ -347,6 +355,9 @@ cdef class Parser: best = arg_max_if_gold(scores, costs, nr_class) self.model.regression_lossC(d_scores, scores, costs) + for i in range(nr_class): + if not is_valid[i]: + d_scores[i] = 0 self.model.set_gradientC(d_scores, features, nr_feat) action = self.moves.c[guess] @@ -354,7 +365,7 @@ cdef class Parser: #print(scores[guess], scores[best], d_scores[guess], costs[guess], # self.moves.move_name(action.move, action.label), stcls.print_state(words)) - loss += scores[guess] + loss += abs(scores[guess] + costs[guess]) memset(context, 0, sizeof(context)) memset(features, 0, sizeof(features[0]) * nr_feat) memset(scores, 0, sizeof(scores[0]) * nr_class) @@ -363,8 +374,7 @@ cdef class Parser: for i in range(nr_class): is_valid[i] = 1 #if itn % 100 == 0: - # self.optimizer(self.model.model[0].ravel(), - # self.model.model[1].ravel(), key=1) + # self.model.finish_update(self.optimizer) return loss def step_through(self, Doc doc):
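
A sketch of the hash-kernel scheme these patches introduce in spacy/_ml.pyx (patches 3, 4 and 6): the LinearModel keeps one flat table of nr_weight * nr_class weights; each 64-bit feature key is split into its two 32-bit halves, each half is reduced modulo the table size, the class-weight rows at those two buckets are summed into the scores, and the gradient is scattered back onto the same buckets. The pure-Python/NumPy code below only illustrates that idea; the class and method names are invented for this note and are not part of the diffs above or of the spaCy/thinc API.

# Illustrative pure-Python sketch of the hash-kernel LinearModel from spacy/_ml.pyx.
# All names here (HashKernelLinearModel, _buckets, scores, update) are invented for
# this note; they are not spaCy or thinc API.
import numpy as np


class HashKernelLinearModel(object):
    def __init__(self, nr_class, nr_weight=2**18, learn_rate=0.001):
        self.nr_class = nr_class
        self.nr_weight = nr_weight
        self.learn_rate = learn_rate
        # One flat table: a row of class weights per hash bucket.
        self.W = np.zeros((nr_weight, nr_class), dtype='float64')

    def _buckets(self, keys):
        # A 64-bit feature key is reinterpreted as two 32-bit halves, and each
        # half is reduced modulo the table size (cf. set_scoresC / set_gradientC).
        keys = np.asarray(keys, dtype='uint64')
        lo = (keys & np.uint64(0xFFFFFFFF)).astype('int64')
        hi = (keys >> np.uint64(32)).astype('int64')
        return np.concatenate([lo, hi]) % self.nr_weight

    def scores(self, keys):
        # Sum the weight rows of every bucket hit by the active features.
        return self.W[self._buckets(keys)].sum(axis=0)

    def update(self, keys, d_scores):
        # Scatter the gradient back onto the same buckets as a plain SGD step;
        # the clipping to [-10, 10] follows the earlier set_gradientC variant.
        d = np.clip(np.asarray(d_scores, dtype='float64'), -10.0, 10.0)
        for idx in self._buckets(keys):
            self.W[idx] -= self.learn_rate * d

In the patches, the keys would come from ConjunctionExtracter.set_features and d_scores from one of the *_lossC methods; averaging, momentum and the dropout added to the parser in patch 7 are left out of this sketch.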