diff --git a/examples/chainer_sentiment.py b/examples/chainer_sentiment.py index ac3881e75..747ef508a 100644 --- a/examples/chainer_sentiment.py +++ b/examples/chainer_sentiment.py @@ -3,6 +3,9 @@ import plac import random import six +import cProfile +import pstats + import pathlib import cPickle as pickle from itertools import izip @@ -81,7 +84,7 @@ class SentimentModel(Chain): def __init__(self, nlp, shape, **settings): Chain.__init__(self, embed=_Embed(shape['nr_vector'], shape['nr_dim'], shape['nr_hidden'], - initialW=lambda arr: set_vectors(arr, nlp.vocab)), + set_vectors=lambda arr: set_vectors(arr, nlp.vocab)), encode=_Encode(shape['nr_hidden'], shape['nr_hidden']), attend=_Attend(shape['nr_hidden'], shape['nr_hidden']), predict=_Predict(shape['nr_hidden'], shape['nr_class'])) @@ -95,11 +98,11 @@ class SentimentModel(Chain): class _Embed(Chain): - def __init__(self, nr_vector, nr_dim, nr_out): + def __init__(self, nr_vector, nr_dim, nr_out, set_vectors=None): Chain.__init__(self, - embed=L.EmbedID(nr_vector, nr_dim), + embed=L.EmbedID(nr_vector, nr_dim, initialW=set_vectors), project=L.Linear(None, nr_out, nobias=True)) - #self.embed.unchain_backward() + self.embed.W.volatile = False def __call__(self, sentence): return [self.project(self.embed(ts)) for ts in F.transpose(sentence)] @@ -214,7 +217,6 @@ def set_vectors(vectors, vocab): vectors[lex.rank + 1] = lex.vector else: lex.norm = 0 - vectors.unchain_backwards() return vectors @@ -223,7 +225,9 @@ def train(train_texts, train_labels, dev_texts, dev_labels, by_sentence=True): nlp = spacy.load('en', entity=False) if 'nr_vector' not in lstm_shape: - lstm_shape['nr_vector'] = max(lex.rank+1 for lex in vocab if lex.has_vector) + lstm_shape['nr_vector'] = max(lex.rank+1 for lex in nlp.vocab if lex.has_vector) + if 'nr_dim' not in lstm_shape: + lstm_shape['nr_dim'] = nlp.vocab.vectors_length print("Make model") model = Classifier(SentimentModel(nlp, lstm_shape, **lstm_settings)) print("Parsing texts...") @@ -240,7 +244,7 @@ def train(train_texts, train_labels, dev_texts, dev_labels, optimizer = chainer.optimizers.Adam() optimizer.setup(model) updater = chainer.training.StandardUpdater(train_iter, optimizer, device=0) - trainer = chainer.training.Trainer(updater, (20, 'epoch'), out='result') + trainer = chainer.training.Trainer(updater, (1, 'epoch'), out='result') trainer.extend(extensions.Evaluator(dev_iter, model, device=0)) trainer.extend(extensions.LogReport()) @@ -305,11 +309,14 @@ def main(model_dir, train_dir, dev_dir, dev_labels = xp.asarray(dev_labels, dtype='i') lstm = train(train_texts, train_labels, dev_texts, dev_labels, {'nr_hidden': nr_hidden, 'max_length': max_length, 'nr_class': 2, - 'nr_vector': 2000, 'nr_dim': 32}, + 'nr_vector': 5000}, {'dropout': 0.5, 'lr': learn_rate}, {}, nb_epoch=nb_epoch, batch_size=batch_size) if __name__ == '__main__': + #cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof") + #s = pstats.Stats("Profile.prof") + #s.strip_dirs().sort_stats("time").print_stats() plac.call(main)