mirror of https://github.com/explosion/spaCy.git

Set vectors in chainer example

parent b701a08249
commit 1ed40682a3
@@ -3,6 +3,9 @@ import plac
 import random
 import six
 
+import cProfile
+import pstats
+
 import pathlib
 import cPickle as pickle
 from itertools import izip
@@ -81,7 +84,7 @@ class SentimentModel(Chain):
     def __init__(self, nlp, shape, **settings):
         Chain.__init__(self,
             embed=_Embed(shape['nr_vector'], shape['nr_dim'], shape['nr_hidden'],
-                initialW=lambda arr: set_vectors(arr, nlp.vocab)),
+                set_vectors=lambda arr: set_vectors(arr, nlp.vocab)),
             encode=_Encode(shape['nr_hidden'], shape['nr_hidden']),
             attend=_Attend(shape['nr_hidden'], shape['nr_hidden']),
             predict=_Predict(shape['nr_hidden'], shape['nr_class']))
@@ -95,11 +98,11 @@ class SentimentModel(Chain):
 
 
 class _Embed(Chain):
-    def __init__(self, nr_vector, nr_dim, nr_out):
+    def __init__(self, nr_vector, nr_dim, nr_out, set_vectors=None):
         Chain.__init__(self,
-            embed=L.EmbedID(nr_vector, nr_dim),
+            embed=L.EmbedID(nr_vector, nr_dim, initialW=set_vectors),
             project=L.Linear(None, nr_out, nobias=True))
-        #self.embed.unchain_backward()
+        self.embed.W.volatile = False
 
     def __call__(self, sentence):
         return [self.project(self.embed(ts)) for ts in F.transpose(sentence)]
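Note: after the two hunks above, `_Embed` accepts a `set_vectors` callable and forwards it to `L.EmbedID` as `initialW`, so the embedding table starts from spaCy's pretrained vectors rather than the default random initialization. A minimal standalone sketch of that mechanism, assuming a Chainer version whose `initialW` accepts a callable that fills the weight array in place (the toy vocab here is hypothetical, not from the commit):

    import numpy as np
    import chainer.links as L

    # Hypothetical stand-in for nlp.vocab: rank -> pretrained vector.
    toy_vectors = {0: np.ones(4, dtype='f'), 1: np.arange(4, dtype='f')}

    def fill_rows(arr):
        arr[...] = 0  # zero everything; row 0 then stays free for unseen ids
        for rank, vec in toy_vectors.items():
            arr[rank + 1] = vec  # same rank+1 offset as set_vectors() below

    embed = L.EmbedID(len(toy_vectors) + 1, 4, initialW=fill_rows)
    print(embed.W.data)  # row 0 is zeros; rows 1..n hold the vectors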
@@ -214,7 +217,6 @@ def set_vectors(vectors, vocab):
             vectors[lex.rank + 1] = lex.vector
         else:
             lex.norm = 0
-    vectors.unchain_backwards()
     return vectors
 
 
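Note on the deleted line: `chainer.Variable` has `unchain_backward` (no trailing "s"), and the `vectors` argument here is the raw weight array handed over by the `initialW` initializer, so `vectors.unchain_backwards()` appears to have had no valid target; dropping it lets the initializer run through. Since `main()` below now fixes the table size at `'nr_vector': 5000`, a defensive variant of this function could also bounds-check the row index. A sketch, not part of the commit:

    def set_vectors_checked(vectors, vocab):
        # Hypothetical variant: copy pretrained vectors into rows rank+1,
        # skipping lexemes whose row would fall outside a fixed-size table.
        for lex in vocab:
            if lex.has_vector and lex.rank + 1 < vectors.shape[0]:
                vectors[lex.rank + 1] = lex.vector
            else:
                lex.norm = 0
        return vectors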
@@ -223,7 +225,9 @@ def train(train_texts, train_labels, dev_texts, dev_labels,
           by_sentence=True):
     nlp = spacy.load('en', entity=False)
     if 'nr_vector' not in lstm_shape:
-        lstm_shape['nr_vector'] = max(lex.rank+1 for lex in vocab if lex.has_vector)
+        lstm_shape['nr_vector'] = max(lex.rank+1 for lex in nlp.vocab if lex.has_vector)
+    if 'nr_dim' not in lstm_shape:
+        lstm_shape['nr_dim'] = nlp.vocab.vectors_length
     print("Make model")
     model = Classifier(SentimentModel(nlp, lstm_shape, **lstm_settings))
     print("Parsing texts...")
@@ -240,7 +244,7 @@ def train(train_texts, train_labels, dev_texts, dev_labels,
     optimizer = chainer.optimizers.Adam()
     optimizer.setup(model)
     updater = chainer.training.StandardUpdater(train_iter, optimizer, device=0)
-    trainer = chainer.training.Trainer(updater, (20, 'epoch'), out='result')
+    trainer = chainer.training.Trainer(updater, (1, 'epoch'), out='result')
 
     trainer.extend(extensions.Evaluator(dev_iter, model, device=0))
     trainer.extend(extensions.LogReport())
@@ -305,11 +309,14 @@ def main(model_dir, train_dir, dev_dir,
     dev_labels = xp.asarray(dev_labels, dtype='i')
     lstm = train(train_texts, train_labels, dev_texts, dev_labels,
                  {'nr_hidden': nr_hidden, 'max_length': max_length, 'nr_class': 2,
-                  'nr_vector': 2000, 'nr_dim': 32},
+                  'nr_vector': 5000},
                  {'dropout': 0.5, 'lr': learn_rate},
                  {},
                  nb_epoch=nb_epoch, batch_size=batch_size)
 
 
 if __name__ == '__main__':
+    #cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof")
+    #s = pstats.Stats("Profile.prof")
+    #s.strip_dirs().sort_stats("time").print_stats()
     plac.call(main)
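Taken together with the earlier shape hunk, `train()` now infers both `nr_vector` and `nr_dim` from the loaded pipeline when the caller omits them, which is why `main()` passes only `'nr_vector': 5000` and no longer hardcodes `'nr_dim': 32`. A sketch of that inference in isolation, mirroring the diff (the placeholder numbers stand in for the CLI defaults):

    import spacy

    nlp = spacy.load('en', entity=False)
    lstm_shape = {'nr_hidden': 64, 'max_length': 100, 'nr_class': 2,
                  'nr_vector': 5000}

    # Same defaults train() fills in when the caller omits them:
    if 'nr_vector' not in lstm_shape:
        lstm_shape['nr_vector'] = max(
            lex.rank + 1 for lex in nlp.vocab if lex.has_vector)
    if 'nr_dim' not in lstm_shape:
        lstm_shape['nr_dim'] = nlp.vocab.vectors_length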