Set vectors in chainer example

Matthew Honnibal 2016-11-19 18:42:58 -06:00
parent b701a08249
commit 1ed40682a3
1 changed file with 15 additions and 8 deletions


@@ -3,6 +3,9 @@ import plac
 import random
 import six
+import cProfile
+import pstats
 import pathlib
 import cPickle as pickle
 from itertools import izip
@@ -81,7 +84,7 @@ class SentimentModel(Chain):
     def __init__(self, nlp, shape, **settings):
         Chain.__init__(self,
             embed=_Embed(shape['nr_vector'], shape['nr_dim'], shape['nr_hidden'],
-                initialW=lambda arr: set_vectors(arr, nlp.vocab)),
+                set_vectors=lambda arr: set_vectors(arr, nlp.vocab)),
             encode=_Encode(shape['nr_hidden'], shape['nr_hidden']),
             attend=_Attend(shape['nr_hidden'], shape['nr_hidden']),
             predict=_Predict(shape['nr_hidden'], shape['nr_class']))
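The only change in this hunk is the keyword rename (initialW becomes set_vectors) so that _Embed can forward it. The surrounding lines show Chainer's composition idiom: child links are registered by passing them as keyword arguments to Chain.__init__, after which they become attributes and their parameters are visible to the optimizer. A minimal sketch of that idiom in the Chainer 1.x style used here (TinyModel and its sizes are illustrative, not part of the example):

import chainer.functions as F
import chainer.links as L
from chainer import Chain

class TinyModel(Chain):
    # Register child links as keyword arguments, exactly as
    # SentimentModel registers embed/encode/attend/predict above.
    def __init__(self, nr_in, nr_hidden, nr_out):
        Chain.__init__(self,
            hidden=L.Linear(nr_in, nr_hidden),
            out=L.Linear(nr_hidden, nr_out))

    def __call__(self, x):
        return self.out(F.relu(self.hidden(x)))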
@@ -95,11 +98,11 @@ class SentimentModel(Chain):
 class _Embed(Chain):
-    def __init__(self, nr_vector, nr_dim, nr_out):
+    def __init__(self, nr_vector, nr_dim, nr_out, set_vectors=None):
         Chain.__init__(self,
-            embed=L.EmbedID(nr_vector, nr_dim),
+            embed=L.EmbedID(nr_vector, nr_dim, initialW=set_vectors),
             project=L.Linear(None, nr_out, nobias=True))
-        #self.embed.unchain_backward()
+        self.embed.W.volatile = False

     def __call__(self, sentence):
         return [self.project(self.embed(ts)) for ts in F.transpose(sentence)]
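_Embed now threads the callable through to L.EmbedID's initialW parameter. Chainer invokes such a callable with the freshly allocated weight array, so the lambda in SentimentModel fills it with spaCy's pre-trained vectors in place. A self-contained sketch of the mechanism with a toy table (pretrained and copy_pretrained are illustrative names):

import numpy as np
import chainer.links as L

# Toy pre-trained table: 10 vectors of width 4.
pretrained = np.random.uniform(-1.0, 1.0, (10, 4)).astype('f')

def copy_pretrained(arr):
    # Chainer calls the initializer with the (nr_vector, nr_dim)
    # weight array; modify it in place rather than returning a value.
    arr[...] = pretrained

embed = L.EmbedID(10, 4, initialW=copy_pretrained)
assert np.array_equal(embed.W.data, pretrained)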
@@ -214,7 +217,6 @@ def set_vectors(vectors, vocab):
             vectors[lex.rank + 1] = lex.vector
         else:
             lex.norm = 0
-    vectors.unchain_backwards()
     return vectors
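Only the tail of set_vectors is visible in this hunk. The deleted line would have failed anyway: vectors is now the plain numpy weight array handed over by initialW, and unchain_backwards is not a method of numpy arrays (unchain_backward belongs to Chainer variables). Based on the visible lines and the lex.has_vector test used in train() below, the full function plausibly reads as follows (a reconstruction, not the verbatim source):

def set_vectors(vectors, vocab):
    # Copy each lexeme's pre-trained vector into row rank + 1,
    # keeping row 0 free as the padding / out-of-vocabulary slot.
    for lex in vocab:
        if lex.has_vector:
            vectors[lex.rank + 1] = lex.vector
        else:
            lex.norm = 0  # flag lexemes that lack a pre-trained vector
    return vectors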
@@ -223,7 +225,9 @@ def train(train_texts, train_labels, dev_texts, dev_labels,
           by_sentence=True):
     nlp = spacy.load('en', entity=False)
     if 'nr_vector' not in lstm_shape:
-        lstm_shape['nr_vector'] = max(lex.rank+1 for lex in vocab if lex.has_vector)
+        lstm_shape['nr_vector'] = max(lex.rank+1 for lex in nlp.vocab if lex.has_vector)
+    if 'nr_dim' not in lstm_shape:
+        lstm_shape['nr_dim'] = nlp.vocab.vectors_length
     print("Make model")
     model = Classifier(SentimentModel(nlp, lstm_shape, **lstm_settings))
     print("Parsing texts...")
@@ -240,7 +244,7 @@ def train(train_texts, train_labels, dev_texts, dev_labels,
     optimizer = chainer.optimizers.Adam()
     optimizer.setup(model)
     updater = chainer.training.StandardUpdater(train_iter, optimizer, device=0)
-    trainer = chainer.training.Trainer(updater, (20, 'epoch'), out='result')
+    trainer = chainer.training.Trainer(updater, (1, 'epoch'), out='result')
     trainer.extend(extensions.Evaluator(dev_iter, model, device=0))
     trainer.extend(extensions.LogReport())
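For context, the changed line sits inside a standard Chainer training harness; the commit just cuts the stop trigger from 20 epochs to 1. A minimal end-to-end sketch of that harness on CPU (device=-1; the example itself uses device=0, the first GPU), assuming model, train_iter and dev_iter are already built as in the surrounding code:

import chainer
from chainer import training
from chainer.training import extensions

optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (1, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(dev_iter, model, device=-1))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/loss', 'validation/main/loss']))
trainer.run()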
@@ -305,11 +309,14 @@ def main(model_dir, train_dir, dev_dir,
     dev_labels = xp.asarray(dev_labels, dtype='i')
     lstm = train(train_texts, train_labels, dev_texts, dev_labels,
                  {'nr_hidden': nr_hidden, 'max_length': max_length, 'nr_class': 2,
-                  'nr_vector': 2000, 'nr_dim': 32},
+                  'nr_vector': 5000},
                  {'dropout': 0.5, 'lr': learn_rate},
                  {},
                  nb_epoch=nb_epoch, batch_size=batch_size)


 if __name__ == '__main__':
+    #cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof")
+    #s = pstats.Stats("Profile.prof")
+    #s.strip_dirs().sort_stats("time").print_stats()
     plac.call(main)
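The commented-out block (and the cProfile/pstats imports added at the top) shows how the script was profiled. A small standalone version of the same recipe, using only the standard library (the profile_call helper name is ours):

import cProfile
import pstats

def profile_call(fn, out='Profile.prof'):
    # Run fn under cProfile, then print the 20 most expensive
    # functions by internal time, as the commented lines above do.
    cProfile.runctx('fn()', globals(), locals(), out)
    stats = pstats.Stats(out)
    stats.strip_dirs().sort_stats('time').print_stats(20)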