diff --git a/spacy/tagger.pxd b/spacy/tagger.pxd index 30626e775..51e465188 100644 --- a/spacy/tagger.pxd +++ b/spacy/tagger.pxd @@ -1,14 +1,13 @@ -from thinc.api cimport AveragedPerceptron -from thinc.api cimport ExampleC +from thinc.linear.avgtron cimport AveragedPerceptron +from thinc.extra.eg cimport Example +from thinc.structs cimport ExampleC from .structs cimport TokenC from .vocab cimport Vocab cdef class TaggerModel(AveragedPerceptron): - cdef void set_features(self, ExampleC* eg, const TokenC* tokens, int i) except * - cdef void set_costs(self, ExampleC* eg, int gold) except * - cdef void update(self, ExampleC* eg) except * + cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except * cdef class Tagger: diff --git a/spacy/tagger.pyx b/spacy/tagger.pyx index 493cc4f99..81b9b2da0 100644 --- a/spacy/tagger.pyx +++ b/spacy/tagger.pyx @@ -5,8 +5,10 @@ from libc.string cimport memset from cymem.cymem cimport Pool from thinc.typedefs cimport atom_t, weight_t -from thinc.api cimport Example, ExampleC -from thinc.features cimport ConjunctionExtracter +from thinc.extra.eg cimport Example +from thinc.structs cimport ExampleC +from thinc.linear.avgtron cimport AveragedPerceptron +from thinc.linalg cimport VecVec from .typedefs cimport attr_t from .tokens.doc cimport Doc @@ -69,7 +71,7 @@ cpdef enum: cdef class TaggerModel(AveragedPerceptron): - cdef void set_features(self, ExampleC* eg, const TokenC* tokens, int i) except *: + cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except *: _fill_from_token(&eg.atoms[P2_orth], &tokens[i-2]) _fill_from_token(&eg.atoms[P1_orth], &tokens[i-1]) _fill_from_token(&eg.atoms[W_orth], &tokens[i]) @@ -78,9 +80,6 @@ cdef class TaggerModel(AveragedPerceptron): eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms) - cdef void update(self, ExampleC* eg) except *: - self.updater.update(eg) - cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil: context[0] = t.lex.lower @@ -143,8 +142,7 @@ cdef class Tagger: @classmethod def blank(cls, vocab, templates): - model = TaggerModel(vocab.morphology.n_tags, - ConjunctionExtracter(N_CONTEXT_FIELDS, templates)) + model = TaggerModel(N_CONTEXT_FIELDS, templates) return cls(vocab, model) @classmethod @@ -159,13 +157,9 @@ cdef class Tagger: # 'pos', 'templates.json', # default=cls.default_templates()) - model = TaggerModel(vocab.morphology.n_tags, - ConjunctionExtracter(N_CONTEXT_FIELDS, templates)) - - - if pkg.has_file('pos', 'model'): # TODO: really optional? + model = TaggerModel(templates) + if pkg.has_file('pos', 'model'): model.load(pkg.file_path('pos', 'model')) - return cls(vocab, model) def __init__(self, Vocab vocab, TaggerModel model): @@ -202,15 +196,16 @@ cdef class Tagger: return 0 cdef Pool mem = Pool() - cdef ExampleC eg cdef int i, tag + cdef Example eg = Example(self.vocab.morphology.n_tags) for i in range(tokens.length): if tokens.c[i].pos == 0: - eg = self.model.allocate(mem) - self.model.set_features(&eg, tokens.c, i) - self.model.set_prediction(&eg) - self.vocab.morphology.assign_tag(&tokens.c[i], eg.guess) + self.model.set_featuresC(&eg.c, tokens.c, i) + self.model.set_scoresC(eg.c.scores, + eg.c.features, eg.c.nr_feat) + guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) + self.vocab.morphology.assign_tag(&tokens.c[i], guess) tokens.is_tagged = True tokens._py_tokens = [None] * tokens.length @@ -219,18 +214,20 @@ cdef class Tagger: golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs] cdef int correct = 0 cdef Pool mem = Pool() - cdef ExampleC eg + cdef Example eg = Example(self.vocab.morphology.n_tags) for i in range(tokens.length): - eg = self.model.allocate(mem) - self.model.set_features(&eg, tokens.c, i) - self.model.set_costs(&eg, golds[i]) - self.model.set_prediction(&eg) - self.model.update(&eg) + self.model.set_featuresC(&eg.c, tokens.c, i) + eg.set_label(golds[i]) + self.model.set_scoresC(eg.c.scores, + eg.c.features, eg.c.nr_feat) + + self.model.updateC(&eg.c) self.vocab.morphology.assign_tag(&tokens.c[i], eg.guess) correct += eg.cost == 0 self.freqs[TAG][tokens.c[i].tag] += 1 + eg.wipe(tuple()) tokens.is_tagged = True tokens._py_tokens = [None] * tokens.length return correct