diff --git a/examples/spacy_dynet_lstm.py b/examples/spacy_dynet_lstm.py new file mode 100644 index 000000000..003d54ee9 --- /dev/null +++ b/examples/spacy_dynet_lstm.py @@ -0,0 +1,190 @@ +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import division + +import pathlib +import plac +import random +from collections import Counter +import numpy as np + +from collections import defaultdict +from itertools import count + +import dynet +from dynet import cg + + +class Vocab: + def __init__(self, w2i=None): + if w2i is None: w2i = defaultdict(count(0).next) + self.w2i = dict(w2i) + self.i2w = {i:w for w,i in w2i.iteritems()} + @classmethod + def from_corpus(cls, corpus): + w2i = defaultdict(count(0).next) + for sent in corpus: + [w2i[word] for word in sent] + return Vocab(w2i) + + def size(self): + return len(self.w2i.keys()) + + +def read_data(path): + with path.open() as file_: + sent = [] + for line in file_: + line = line.strip().split() + if not line: + if sent: + yield sent + sent = [] + else: + pieces = line + w = pieces[1] + pos = pieces[3] + sent.append((w, pos)) + + +def get_vocab(train, test): + words = [] + tags = [] + wc = Counter() + for s in train: + for w, p in s: + words.append(w) + tags.append(p) + wc[w] += 1 + words.append("_UNK_") + #words=[w if wc[w] > 1 else "_UNK_" for w in words] + tags.append("_START_") + + for s in test: + for w, p in s: + words.append(w) + vw = Vocab.from_corpus([words]) + vt = Vocab.from_corpus([tags]) + UNK = vw.w2i["_UNK_"] + return words, tags, wc, vw, vt + + +class BiTagger(object): + def __init__(self, nwords, ntags): + self.nwords = nwords + self.ntags = ntags + + self._model = dynet.Model() + self._sgd = dynet.SimpleSGDTrainer(self._model) + + self._E = self._model.add_lookup_parameters((self.nwords, 128)) + self._p_t1 = self._model.add_lookup_parameters((self.ntags, 30)) + + self._pH = self._model.add_parameters((32, 50*2)) + self._pO = self._model.add_parameters((self.ntags, 32)) + + self._fwd_lstm = dynet.LSTMBuilder(1, 128, 50, self._model) + self._bwd_lstm = dynet.LSTMBuilder(1, 128, 50, self._model) + + def __call__(self, doc): + dynet.renew_cg() + + wembs = [self._E[word.rank] for word in doc] + + f_state = self._fwd_lstm.initial_state() + b_state = self._bwd_lstm.initial_state() + + fw = [x.output() for x in f_state.add_inputs(wembs)] + bw = [x.output() for x in b_state.add_inputs(reversed(wembs))] + + H = dynet.parameter(self._pH) + O = dynet.parameter(self._pO) + + for i, (f, b) in enumerate(zip(fw, reversed(bw))): + r_t = O * (dynet.tanh(H * dynet.concatenate([f, b]))) + out = dynet.softmax(r_t) + doc[i].tag = np.argmax(out.npvalue()) + + def update(self, doc, gold): + dynet.renew_cg() + wembs = [self._E[word.rank] for word in doc] + wembs = [dynet.noise(we, 0.1) for we in wembs] + + f_state = self._fwd_lstm.initial_state() + b_state = self._bwd_lstm.initial_state() + + fw = [x.output() for x in f_state.add_inputs(wembs)] + bw = [x.output() for x in b_state.add_inputs(reversed(wembs))] + + H = dynet.parameter(self._pH) + O = dynet.parameter(self._pO) + + errs = [] + for f, b, t in zip(fw, reversed(bw), tags): + f_b = dynet.concatenate([f,b]) + r_t = O * (dynet.tanh(H * f_b)) + err = dynet.pickneglogsoftmax(r_t, t) + errs.append(err) + + sum_errs = dynet.esum(errs) + squared = -sum_errs # * sum_errs + loss += sum_errs.scalar_value() + sum_errs.backward() + sgd.update() + + +def main(train_loc, dev_loc, model_dir): + train_loc = pathlib.Path(train_loc) + dev_loc = pathlib.Path(dev_loc) + + train = list(read_data((train_loc))) + test = list(read_data(dev_loc)) + + tagger = BiTagger(vocab) + + UNK = vw.w2i["_UNK_"] + nwords = vw.size() + ntags = vt.size() + + model = dynet.Model() + sgd = dynet.SimpleSGDTrainer(model) + + E = model.add_lookup_parameters((nwords, 128)) + p_t1 = model.add_lookup_parameters((ntags, 30)) + + pH = model.add_parameters((32, 50*2)) + pO = model.add_parameters((ntags, 32)) + + builders=[ + dynet.LSTMBuilder(1, 128, 50, model), + dynet.LSTMBuilder(1, 128, 50, model), + ] + + tagged = loss = 0 + for ITER in xrange(50): + random.shuffle(train) + for i, s in enumerate(train,1): + if i % 5000 == 0: + sgd.status() + print(loss / tagged) + loss = 0 + tagged = 0 + if i % 10000 == 0: + good = bad = 0.0 + for sent in test: + word_ids = [vw.w2i.get(w, UNK) for w, t in sent] + tags = tagger.tag_sent(word_ids) + golds = [t for w, t in sent] + for go, gu in zip(golds, tags): + if go == gu: + good += 1 + else: + bad += 1 + print(good / (good+bad)) + ws = [vw.w2i.get(w, UNK) for w,p in s] + ps = [vt.w2i[p] for w, p in s] + model.update(ws, ps) + + +if __name__ == '__main__': + plac.call(main)