Minibatch the forward pass. THe output argmax is incorrect...

This commit is contained in:
Matthew Honnibal 2016-11-16 06:15:28 -06:00
parent 8f053fd943
commit 718e66a7b9
1 changed files with 39 additions and 4 deletions

View File

@ -118,6 +118,42 @@ class BiTagger(object):
tags.append(self.vt.i2w[np.argmax(out.npvalue())])
return tags
def predict_batch(self, words_batch):
dynet.renew_cg()
length = max(len(words) for words in words_batch)
word_ids = np.zeros((length, len(words_batch)), dtype='int32')
for j, words in enumerate(words_batch):
for i, word in enumerate(words):
word_ids[i, j] = self.vw.w2i.get(word, self.UNK)
wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]
f_state = self._fwd_lstm.initial_state()
b_state = self._bwd_lstm.initial_state()
fw = [x.output() for x in f_state.add_inputs(wembs)]
bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]
H = dynet.parameter(self._pH)
O = dynet.parameter(self._pO)
tags_batch = [[] for _ in range(len(words_batch))]
for i, (f, b) in enumerate(zip(fw, reversed(bw))):
r_t = O * (dynet.tanh(H * dynet.concatenate([f, b])))
out = dynet.softmax(r_t).npvalue()
for j in range(len(words_batch)):
tags_batch[j].append(self.vt.i2w[np.argmax(out.T[j])])
return tags_batch
def pipe(self, sentences):
batch = []
for words in sentences:
batch.append(words)
if len(batch) == self._minibatch_size:
tags_batch = self.predict_batch(batch)
for words, tags in zip(batch, tags_batch):
yield tags
batch = []
def update(self, words, tags):
self._words_batch.append(words)
self._tags_batch.append(tags)
@ -193,10 +229,9 @@ def main(train_loc, dev_loc, model_dir):
tagged = 0
if i % 10000 == 0:
good = bad = 0.0
for sent in test:
#word_ids = [vw.w2i.get(w, UNK) for w, t in sent]
tags = tagger([w for w, t in sent])
golds = [t for w, t in sent]
word_sents = [[w for w, t in sent] for sent in test]
gold_sents = [[t for w, t in sent] for sent in test]
for words, tags, golds in zip(words, tagger.pipe(words), gold_sents):
for go, gu in zip(golds, tags):
if go == gu:
good += 1