2015-06-26 11:51:39 +00:00
|
|
|
from thinc.api cimport Example
|
|
|
|
from thinc.typedefs cimport weight_t
|
|
|
|
|
|
|
|
from ._ml cimport arg_max_if_true
|
|
|
|
from ._ml cimport arg_max_if_zero
|
|
|
|
|
|
|
|
import numpy
|
|
|
|
from os import path
|
2015-06-23 20:55:58 +00:00
|
|
|
|
|
|
|
|
|
|
|
cdef class TheanoModel(Model):
    """Model whose scoring and training are delegated to externally built callables.

    An ``InputLayer`` turns an ``Example``'s atoms into ``eg.embeddings``. The
    user-supplied ``predict_func`` / ``train_func`` (presumably compiled Theano
    functions, going by the class name) produce class scores from those
    embeddings. The results are copied back into the ``Example``'s score,
    guess, best and cost fields.
    """

    def __init__(self, n_classes, input_spec, train_func, predict_func, model_loc=None,
                 debug=None):
        """Store the callables and hyper-parameters and build the input layer.

        Args:
            n_classes: Number of output classes. ``predict`` and ``train`` copy
                this many scores into ``eg.scores`` and search this many entries
                when choosing ``eg.guess`` / ``eg.best``.
            input_spec: Forwarded unchanged to ``InputLayer``; its structure is
                defined by that class (not visible here).
            train_func: Called as ``train_func(embeddings, costs, eta)``. It must
                return a 3-tuple whose first two items are the score vector and
                the update passed to ``InputLayer.update``.
            predict_func: Called as ``predict_func(embeddings)``. Element ``[0]``
                of its result is used as the score vector.
            model_loc: Optional path. If it is an existing directory, the path
                ``<model_loc>/model`` is stored instead.
            debug: Stored as ``self.debug``; not read by any method in this class.
        """
        # A directory is accepted as shorthand for the 'model' file inside it.
        if model_loc is not None and path.isdir(model_loc):
            model_loc = path.join(model_loc, 'model')

        # Hyper-parameters forwarded to train_func / InputLayer.update() in
        # train(). NOTE(review): eta looks like a learning rate and mu like a
        # momentum term; their exact meaning lives in those callees. Confirm.
        self.eta = 0.001
        self.mu = 0.9
        # Training-step counter passed to InputLayer.update(); train()
        # increments it once per call.
        self.t = 1

        # Draws each initial input weight uniformly from [-0.2, 0.2).
        initializer = lambda: 0.2 * numpy.random.uniform(-1.0, 1.0)
        self.input_layer = InputLayer(input_spec, initializer)

        self.train_func = train_func
        self.predict_func = predict_func
        self.debug = debug

        self.n_classes = n_classes
        self.n_feats = len(self.input_layer)
        self.model_loc = model_loc

    def predict(self, Example eg):
        """Score ``eg`` and set ``eg.guess``, without updating any parameters.

        The input layer is filled with ``use_avg=True`` (presumably the averaged
        parameters used at inference time). ``predict_func`` is then run on
        ``eg.embeddings``, and the first ``n_classes`` scores are copied into
        ``eg.scores``. Finally ``eg.guess`` is chosen by ``arg_max_if_true``
        over ``eg.is_valid``; going by the helper's name, this is the
        highest-scoring class flagged as valid.
        """
        cdef int i
        self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=True)
        theano_scores = self.predict_func(eg.embeddings)[0]
        for i in range(self.n_classes):
            eg.scores[i] = theano_scores[i]
        eg.guess = arg_max_if_true(<weight_t*>eg.scores.data, <int*>eg.is_valid.data,
                                   self.n_classes)

    def train(self, Example eg):
        """Run one training step on ``eg`` and record the resulting decisions.

        Side effects on ``eg``:
          * ``eg.embeddings`` is filled with ``use_avg=False``.
          * ``eg.scores`` receives the first ``n_classes`` scores from ``train_func``.
          * ``eg.guess`` is set by ``arg_max_if_true`` over ``eg.is_valid``.
          * ``eg.best`` is set by ``arg_max_if_zero`` over ``eg.costs``
            (presumably the highest-scoring zero-cost class).
          * ``eg.cost`` is set to ``eg.costs[eg.guess]``.

        Side effects on ``self``: ``input_layer.update`` is applied with the
        update returned by ``train_func``, then ``self.t`` is incremented.
        """
        # Typed index, matching predict(), so the copy loop compiles to a C loop
        # instead of iterating over Python int objects.
        cdef int i
        self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=False)
        # The third value returned by train_func is not used here.
        theano_scores, update, _ = self.train_func(eg.embeddings, eg.costs, self.eta)
        self.input_layer.update(update, eg.atoms, self.t, self.eta, self.mu)
        for i in range(self.n_classes):
            eg.scores[i] = theano_scores[i]
        eg.guess = arg_max_if_true(<weight_t*>eg.scores.data, <int*>eg.is_valid.data,
                                   self.n_classes)
        eg.best = arg_max_if_zero(<weight_t*>eg.scores.data, <int*>eg.costs.data,
                                  self.n_classes)
        # NOTE(review): assumes arg_max_if_true always returns an in-range index
        # (i.e. at least one class is valid). Confirm against its definition.
        eg.cost = eg.costs[eg.guess]
        self.t += 1

    def end_training(self):
        """Finish training. This model has nothing to finalise, so it is a no-op."""
        pass
|