diff --git a/spacy/_ml.pxd b/spacy/_ml.pxd index 7024e88fc..e19a3a480 100644 --- a/spacy/_ml.pxd +++ b/spacy/_ml.pxd @@ -18,18 +18,10 @@ cdef int arg_max(const weight_t* scores, const int n_classes) nogil cdef class Model: cdef int n_classes - cdef int regularize(self, Feature* feats, int n, int a=*) except -1 + cdef const weight_t* score(self, atom_t* context, bint regularize) except NULL cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1 - + cdef object model_loc cdef Extractor _extractor cdef LinearModel _model - - cdef inline const weight_t* score(self, atom_t* context, bint regularize) except NULL: - cdef int n_feats - feats = self._extractor.get_feats(context, &n_feats) - if regularize: - self.regularize(feats, n_feats, 3) - return self._model.get_scores(feats, n_feats) - diff --git a/spacy/_ml.pyx b/spacy/_ml.pyx index 3dffed611..a7599ecf6 100644 --- a/spacy/_ml.pyx +++ b/spacy/_ml.pyx @@ -33,6 +33,11 @@ cdef class Model: if self.model_loc and path.exists(self.model_loc): self._model.load(self.model_loc, freq_thresh=0) + cdef const weight_t* score(self, atom_t* context, bint regularize) except NULL: + cdef int n_feats + feats = self._extractor.get_feats(context, &n_feats) + return self._model.get_scores(feats, n_feats) + cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1: cdef int n_feats if cost == 0: @@ -44,19 +49,6 @@ cdef class Model: count_feats(counts[guess], feats, n_feats, -cost) self._model.update(counts) - @cython.cdivision - @cython.boundscheck(False) - cdef int regularize(self, Feature* feats, int n, int a=3) except -1: - pass - # Disable this for now, while we investigate effect. - # Use the Zipfian corruptions technique from here: - # http://www.aclweb.org/anthology/N13-1077 - # This seems good for 0.1 - 0.3 % on OOD data. - #cdef int i - #cdef long[:] zipfs = numpy.random.zipf(a, n) - #for i in range(n): - # feats[i].value *= 1 / zipfs[i] - def end_training(self): self._model.end_training() self._model.dump(self.model_loc, freq_thresh=0)