mirror of https://github.com/explosion/spaCy.git
* Remove hard-coding of vector lengths
This commit is contained in:
parent
7b8275fcc4
commit
e7003f1cf3
|
@ -28,7 +28,7 @@ cdef class TheanoModel(Model):
|
|||
self.model_loc = model_loc
|
||||
|
||||
def predict(self, Example eg):
|
||||
self.input_layer.fill(eg.embeddings, eg.atoms)
|
||||
self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=True)
|
||||
theano_scores = self.predict_func(eg.embeddings)[0]
|
||||
cdef int i
|
||||
for i in range(self.n_classes):
|
||||
|
@ -37,7 +37,7 @@ cdef class TheanoModel(Model):
|
|||
self.n_classes)
|
||||
|
||||
def train(self, Example eg):
|
||||
self.input_layer.fill(eg.embeddings, eg.atoms)
|
||||
self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=False)
|
||||
theano_scores, update, y = self.train_func(eg.embeddings, eg.costs, self.eta)
|
||||
self.input_layer.update(update, eg.atoms, self.t, self.eta, self.mu)
|
||||
for i in range(self.n_classes):
|
||||
|
|
|
@ -52,7 +52,7 @@ def get_templates(name):
|
|||
elif name == 'debug':
|
||||
return pf.unigrams
|
||||
elif name.startswith('embed'):
|
||||
return ((10, pf.words), (10, pf.tags), (10, pf.labels))
|
||||
return (pf.words, pf.tags, pf.labels)
|
||||
else:
|
||||
return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \
|
||||
pf.tree_shape + pf.trigrams)
|
||||
|
|
Loading…
Reference in New Issue