* Remove hard-coding of vector lengths

This commit is contained in:
Matthew Honnibal 2015-06-27 04:18:47 +02:00
parent 7b8275fcc4
commit e7003f1cf3
2 changed files with 3 additions and 3 deletions

View File

@ -28,7 +28,7 @@ cdef class TheanoModel(Model):
self.model_loc = model_loc self.model_loc = model_loc
def predict(self, Example eg): def predict(self, Example eg):
self.input_layer.fill(eg.embeddings, eg.atoms) self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=True)
theano_scores = self.predict_func(eg.embeddings)[0] theano_scores = self.predict_func(eg.embeddings)[0]
cdef int i cdef int i
for i in range(self.n_classes): for i in range(self.n_classes):
@ -37,7 +37,7 @@ cdef class TheanoModel(Model):
self.n_classes) self.n_classes)
def train(self, Example eg): def train(self, Example eg):
self.input_layer.fill(eg.embeddings, eg.atoms) self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=False)
theano_scores, update, y = self.train_func(eg.embeddings, eg.costs, self.eta) theano_scores, update, y = self.train_func(eg.embeddings, eg.costs, self.eta)
self.input_layer.update(update, eg.atoms, self.t, self.eta, self.mu) self.input_layer.update(update, eg.atoms, self.t, self.eta, self.mu)
for i in range(self.n_classes): for i in range(self.n_classes):

View File

@ -52,7 +52,7 @@ def get_templates(name):
elif name == 'debug': elif name == 'debug':
return pf.unigrams return pf.unigrams
elif name.startswith('embed'): elif name.startswith('embed'):
return ((10, pf.words), (10, pf.tags), (10, pf.labels)) return (pf.words, pf.tags, pf.labels)
else: else:
return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \ return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \
pf.tree_shape + pf.trigrams) pf.tree_shape + pf.trigrams)