diff --git a/examples/spacy_dynet_lstm.py b/examples/spacy_dynet_lstm.py index ca3b36056..4ddc7d2a7 100644 --- a/examples/spacy_dynet_lstm.py +++ b/examples/spacy_dynet_lstm.py @@ -118,6 +118,42 @@ class BiTagger(object): tags.append(self.vt.i2w[np.argmax(out.npvalue())]) return tags + def predict_batch(self, words_batch): + dynet.renew_cg() + length = max(len(words) for words in words_batch) + word_ids = np.zeros((length, len(words_batch)), dtype='int32') + for j, words in enumerate(words_batch): + for i, word in enumerate(words): + word_ids[i, j] = self.vw.w2i.get(word, self.UNK) + wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)] + + f_state = self._fwd_lstm.initial_state() + b_state = self._bwd_lstm.initial_state() + + fw = [x.output() for x in f_state.add_inputs(wembs)] + bw = [x.output() for x in b_state.add_inputs(reversed(wembs))] + + H = dynet.parameter(self._pH) + O = dynet.parameter(self._pO) + + tags_batch = [[] for _ in range(len(words_batch))] + for i, (f, b) in enumerate(zip(fw, reversed(bw))): + r_t = O * (dynet.tanh(H * dynet.concatenate([f, b]))) + out = dynet.softmax(r_t).npvalue() + for j in range(len(words_batch)): + tags_batch[j].append(self.vt.i2w[np.argmax(out.T[j])]) + return tags_batch + + def pipe(self, sentences): + batch = [] + for words in sentences: + batch.append(words) + if len(batch) == self._minibatch_size: + tags_batch = self.predict_batch(batch) + for words, tags in zip(batch, tags_batch): + yield tags + batch = [] + def update(self, words, tags): self._words_batch.append(words) self._tags_batch.append(tags) @@ -193,10 +229,9 @@ def main(train_loc, dev_loc, model_dir): tagged = 0 if i % 10000 == 0: good = bad = 0.0 - for sent in test: - #word_ids = [vw.w2i.get(w, UNK) for w, t in sent] - tags = tagger([w for w, t in sent]) - golds = [t for w, t in sent] + word_sents = [[w for w, t in sent] for sent in test] + gold_sents = [[t for w, t in sent] for sent in test] + for words, tags, golds in zip(words, tagger.pipe(words), gold_sents): for go, gu in zip(golds, tags): if go == gu: good += 1