* Make constructor of ParserModel and TaggerModel the same as AveragedPerceptron, for each pickling.

This commit is contained in:
Matthew Honnibal 2015-11-07 18:25:17 +11:00
parent 1cfa20fb17
commit 6f47074214
2 changed files with 6 additions and 11 deletions

View File

@ -63,10 +63,6 @@ def ParserFactory(transition_system):
cdef class ParserModel(AveragedPerceptron):
def __init__(self, n_classes, templates):
AveragedPerceptron.__init__(self, n_classes,
ConjunctionExtracter(CONTEXT_SIZE, templates))
cdef void set_features(self, ExampleC* eg, StateClass stcls) except *:
fill_context(eg.atoms, stcls)
eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms)
@ -86,7 +82,8 @@ cdef class Parser:
cfg = Config.read(model_dir, 'config')
moves = transition_system(strings, cfg.labels)
templates = get_templates(cfg.features)
model = ParserModel(moves.n_moves, templates)
model = ParserModel(moves.n_moves,
ConjunctionExtracter(CONTEXT_SIZE, templates))
if path.exists(path.join(model_dir, 'model')):
model.load(path.join(model_dir, 'model'))
return cls(strings, moves, model)

View File

@ -67,10 +67,6 @@ cpdef enum:
cdef class TaggerModel(AveragedPerceptron):
def __init__(self, n_classes, templates):
AveragedPerceptron.__init__(self, n_classes,
ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
cdef void set_features(self, ExampleC* eg, const TokenC* tokens, int i) except *:
_fill_from_token(&eg.atoms[P2_orth], &tokens[i-2])
_fill_from_token(&eg.atoms[P1_orth], &tokens[i-1])
@ -145,7 +141,8 @@ cdef class Tagger:
@classmethod
def blank(cls, vocab, templates):
model = TaggerModel(vocab.morphology.n_tags, templates)
model = TaggerModel(vocab.morphology.n_tags,
ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
return cls(vocab, model)
@classmethod
@ -154,7 +151,8 @@ cdef class Tagger:
templates = json.loads(open(path.join(data_dir, 'templates.json')))
else:
templates = cls.default_templates()
model = TaggerModel(vocab.morphology.n_tags, templates)
model = TaggerModel(vocab.morphology.n_tags,
ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
if path.exists(path.join(data_dir, 'model')):
model.load(path.join(data_dir, 'model'))
return cls(vocab, model)