Work on DyNet example

This commit is contained in:
Matthew Honnibal 2016-11-13 13:45:21 +01:00
parent 89df91846c
commit ae681aa555
1 changed files with 190 additions and 0 deletions

View File

@ -0,0 +1,190 @@
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
import pathlib
import plac
import random
from collections import Counter
import numpy as np
from collections import defaultdict
from itertools import count
import dynet
from dynet import cg
class Vocab:
def __init__(self, w2i=None):
if w2i is None: w2i = defaultdict(count(0).next)
self.w2i = dict(w2i)
self.i2w = {i:w for w,i in w2i.iteritems()}
@classmethod
def from_corpus(cls, corpus):
w2i = defaultdict(count(0).next)
for sent in corpus:
[w2i[word] for word in sent]
return Vocab(w2i)
def size(self):
return len(self.w2i.keys())
def read_data(path):
with path.open() as file_:
sent = []
for line in file_:
line = line.strip().split()
if not line:
if sent:
yield sent
sent = []
else:
pieces = line
w = pieces[1]
pos = pieces[3]
sent.append((w, pos))
def get_vocab(train, test):
words = []
tags = []
wc = Counter()
for s in train:
for w, p in s:
words.append(w)
tags.append(p)
wc[w] += 1
words.append("_UNK_")
#words=[w if wc[w] > 1 else "_UNK_" for w in words]
tags.append("_START_")
for s in test:
for w, p in s:
words.append(w)
vw = Vocab.from_corpus([words])
vt = Vocab.from_corpus([tags])
UNK = vw.w2i["_UNK_"]
return words, tags, wc, vw, vt
class BiTagger(object):
def __init__(self, nwords, ntags):
self.nwords = nwords
self.ntags = ntags
self._model = dynet.Model()
self._sgd = dynet.SimpleSGDTrainer(self._model)
self._E = self._model.add_lookup_parameters((self.nwords, 128))
self._p_t1 = self._model.add_lookup_parameters((self.ntags, 30))
self._pH = self._model.add_parameters((32, 50*2))
self._pO = self._model.add_parameters((self.ntags, 32))
self._fwd_lstm = dynet.LSTMBuilder(1, 128, 50, self._model)
self._bwd_lstm = dynet.LSTMBuilder(1, 128, 50, self._model)
def __call__(self, doc):
dynet.renew_cg()
wembs = [self._E[word.rank] for word in doc]
f_state = self._fwd_lstm.initial_state()
b_state = self._bwd_lstm.initial_state()
fw = [x.output() for x in f_state.add_inputs(wembs)]
bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]
H = dynet.parameter(self._pH)
O = dynet.parameter(self._pO)
for i, (f, b) in enumerate(zip(fw, reversed(bw))):
r_t = O * (dynet.tanh(H * dynet.concatenate([f, b])))
out = dynet.softmax(r_t)
doc[i].tag = np.argmax(out.npvalue())
def update(self, doc, gold):
dynet.renew_cg()
wembs = [self._E[word.rank] for word in doc]
wembs = [dynet.noise(we, 0.1) for we in wembs]
f_state = self._fwd_lstm.initial_state()
b_state = self._bwd_lstm.initial_state()
fw = [x.output() for x in f_state.add_inputs(wembs)]
bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]
H = dynet.parameter(self._pH)
O = dynet.parameter(self._pO)
errs = []
for f, b, t in zip(fw, reversed(bw), tags):
f_b = dynet.concatenate([f,b])
r_t = O * (dynet.tanh(H * f_b))
err = dynet.pickneglogsoftmax(r_t, t)
errs.append(err)
sum_errs = dynet.esum(errs)
squared = -sum_errs # * sum_errs
loss += sum_errs.scalar_value()
sum_errs.backward()
sgd.update()
def main(train_loc, dev_loc, model_dir):
train_loc = pathlib.Path(train_loc)
dev_loc = pathlib.Path(dev_loc)
train = list(read_data((train_loc)))
test = list(read_data(dev_loc))
tagger = BiTagger(vocab)
UNK = vw.w2i["_UNK_"]
nwords = vw.size()
ntags = vt.size()
model = dynet.Model()
sgd = dynet.SimpleSGDTrainer(model)
E = model.add_lookup_parameters((nwords, 128))
p_t1 = model.add_lookup_parameters((ntags, 30))
pH = model.add_parameters((32, 50*2))
pO = model.add_parameters((ntags, 32))
builders=[
dynet.LSTMBuilder(1, 128, 50, model),
dynet.LSTMBuilder(1, 128, 50, model),
]
tagged = loss = 0
for ITER in xrange(50):
random.shuffle(train)
for i, s in enumerate(train,1):
if i % 5000 == 0:
sgd.status()
print(loss / tagged)
loss = 0
tagged = 0
if i % 10000 == 0:
good = bad = 0.0
for sent in test:
word_ids = [vw.w2i.get(w, UNK) for w, t in sent]
tags = tagger.tag_sent(word_ids)
golds = [t for w, t in sent]
for go, gu in zip(golds, tags):
if go == gu:
good += 1
else:
bad += 1
print(good / (good+bad))
ws = [vw.w2i.get(w, UNK) for w,p in s]
ps = [vt.w2i[p] for w, p in s]
model.update(ws, ps)
if __name__ == '__main__':
plac.call(main)