mirror of https://github.com/explosion/spaCy.git
* Make constructor of ParserModel and TaggerModel the same as AveragedPerceptron, for each pickling.
This commit is contained in:
parent
1cfa20fb17
commit
6f47074214
|
@ -63,10 +63,6 @@ def ParserFactory(transition_system):
|
|||
|
||||
|
||||
cdef class ParserModel(AveragedPerceptron):
|
||||
def __init__(self, n_classes, templates):
|
||||
AveragedPerceptron.__init__(self, n_classes,
|
||||
ConjunctionExtracter(CONTEXT_SIZE, templates))
|
||||
|
||||
cdef void set_features(self, ExampleC* eg, StateClass stcls) except *:
|
||||
fill_context(eg.atoms, stcls)
|
||||
eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms)
|
||||
|
@ -86,7 +82,8 @@ cdef class Parser:
|
|||
cfg = Config.read(model_dir, 'config')
|
||||
moves = transition_system(strings, cfg.labels)
|
||||
templates = get_templates(cfg.features)
|
||||
model = ParserModel(moves.n_moves, templates)
|
||||
model = ParserModel(moves.n_moves,
|
||||
ConjunctionExtracter(CONTEXT_SIZE, templates))
|
||||
if path.exists(path.join(model_dir, 'model')):
|
||||
model.load(path.join(model_dir, 'model'))
|
||||
return cls(strings, moves, model)
|
||||
|
|
|
@ -67,10 +67,6 @@ cpdef enum:
|
|||
|
||||
|
||||
cdef class TaggerModel(AveragedPerceptron):
|
||||
def __init__(self, n_classes, templates):
|
||||
AveragedPerceptron.__init__(self, n_classes,
|
||||
ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
|
||||
|
||||
cdef void set_features(self, ExampleC* eg, const TokenC* tokens, int i) except *:
|
||||
_fill_from_token(&eg.atoms[P2_orth], &tokens[i-2])
|
||||
_fill_from_token(&eg.atoms[P1_orth], &tokens[i-1])
|
||||
|
@ -145,7 +141,8 @@ cdef class Tagger:
|
|||
|
||||
@classmethod
|
||||
def blank(cls, vocab, templates):
|
||||
model = TaggerModel(vocab.morphology.n_tags, templates)
|
||||
model = TaggerModel(vocab.morphology.n_tags,
|
||||
ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
|
||||
return cls(vocab, model)
|
||||
|
||||
@classmethod
|
||||
|
@ -154,7 +151,8 @@ cdef class Tagger:
|
|||
templates = json.loads(open(path.join(data_dir, 'templates.json')))
|
||||
else:
|
||||
templates = cls.default_templates()
|
||||
model = TaggerModel(vocab.morphology.n_tags, templates)
|
||||
model = TaggerModel(vocab.morphology.n_tags,
|
||||
ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
|
||||
if path.exists(path.join(data_dir, 'model')):
|
||||
model.load(path.join(data_dir, 'model'))
|
||||
return cls(vocab, model)
|
||||
|
|
Loading…
Reference in New Issue