* Bridge to Theano working. Very disorganised. Using thinc adb60aba966ed2

This commit is contained in:
Matthew Honnibal 2015-06-27 02:38:51 +02:00
parent 2fe98b8a9a
commit f8bb43475e
5 changed files with 291 additions and 7 deletions

255
bin/parser/nn_train.py Executable file
View File

@ -0,0 +1,255 @@
#!/usr/bin/env python
from __future__ import division
from __future__ import unicode_literals
import os
from os import path
import shutil
import codecs
import random
import plac
import cProfile
import pstats
import re
import spacy.util
from spacy.en import English
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
from spacy.syntax.util import Config
from spacy.gold import read_json_file
from spacy.gold import GoldParse
from spacy.scorer import Scorer
from thinc.theano_nn import compile_theano_model
from spacy.syntax.parser import Parser
from spacy._theano import TheanoModel
def _corrupt(c, noise_level):
if random.random() >= noise_level:
return c
elif c == ' ':
return '\n'
elif c == '\n':
return ' '
elif c in ['.', "'", "!", "?"]:
return ''
else:
return c.lower()
def add_noise(orig, noise_level):
if random.random() >= noise_level:
return orig
elif type(orig) == list:
corrupted = [_corrupt(word, noise_level) for word in orig]
corrupted = [w for w in corrupted if w]
return corrupted
else:
return ''.join(_corrupt(c, noise_level) for c in orig)
def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
if raw_text is None:
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
else:
tokens = nlp.tokenizer(raw_text)
nlp.tagger(tokens)
nlp.entity(tokens)
nlp.parser(tokens)
gold = GoldParse(tokens, annot_tuples)
scorer.score(tokens, gold, verbose=verbose)
def _merge_sents(sents):
m_deps = [[], [], [], [], [], []]
m_brackets = []
i = 0
for (ids, words, tags, heads, labels, ner), brackets in sents:
m_deps[0].extend(id_ + i for id_ in ids)
m_deps[1].extend(words)
m_deps[2].extend(tags)
m_deps[3].extend(head + i for head in heads)
m_deps[4].extend(labels)
m_deps[5].extend(ner)
m_brackets.extend((b['first'] + i, b['last'] + i, b['label']) for b in brackets)
i += len(ids)
return [(m_deps, m_brackets)]
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
verbose=False,
eta=0.01, mu=0.9, n_hidden=100, word_vec_len=10, pos_vec_len=10):
dep_model_dir = path.join(model_dir, 'deps')
pos_model_dir = path.join(model_dir, 'pos')
ner_model_dir = path.join(model_dir, 'ner')
if path.exists(dep_model_dir):
shutil.rmtree(dep_model_dir)
if path.exists(pos_model_dir):
shutil.rmtree(pos_model_dir)
if path.exists(ner_model_dir):
shutil.rmtree(ner_model_dir)
os.mkdir(dep_model_dir)
os.mkdir(pos_model_dir)
os.mkdir(ner_model_dir)
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
labels=Language.ParserTransitionSystem.get_labels(gold_tuples))
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
labels=Language.EntityTransitionSystem.get_labels(gold_tuples),
beam_width=0)
if n_sents > 0:
gold_tuples = gold_tuples[:n_sents]
nlp = Language(data_dir=model_dir)
def make_model(n_classes, input_spec, model_dir):
print input_spec
n_in = sum(n_cols * len(fields) for (n_cols, fields) in input_spec)
print 'Compiling'
debug, train_func, predict_func = compile_theano_model(n_classes, n_hidden,
n_in, 0.0, 0.0)
print 'Done'
return TheanoModel(
n_classes,
input_spec,
train_func,
predict_func,
model_loc=model_dir,
debug=debug)
nlp._parser = Parser(nlp.vocab.strings, dep_model_dir, nlp.ParserTransitionSystem,
make_model)
print "Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %"
for itn in range(n_iter):
scorer = Scorer()
loss = 0
for raw_text, sents in gold_tuples:
if gold_preproc:
raw_text = None
else:
sents = _merge_sents(sents)
for annot_tuples, ctnt in sents:
if len(annot_tuples[1]) == 1:
continue
score_model(scorer, nlp, raw_text, annot_tuples,
verbose=verbose if itn >= 2 else False)
if raw_text is None:
words = add_noise(annot_tuples[1], corruption_level)
tokens = nlp.tokenizer.tokens_from_list(words)
else:
raw_text = add_noise(raw_text, corruption_level)
tokens = nlp.tokenizer(raw_text)
nlp.tagger(tokens)
gold = GoldParse(tokens, annot_tuples, make_projective=True)
if not gold.is_projective:
raise Exception(
"Non-projective sentence in training, after we should "
"have enforced projectivity: %s" % annot_tuples
)
loss += nlp.parser.train(tokens, gold)
nlp.entity.train(tokens, gold)
nlp.tagger.train(tokens, gold.tags)
random.shuffle(gold_tuples)
print '%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
scorer.tags_acc,
scorer.token_acc)
nlp.parser.model.end_training()
nlp.entity.model.end_training()
nlp.tagger.model.end_training()
nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt'))
return nlp
def evaluate(nlp, gold_tuples, gold_preproc=True):
scorer = Scorer()
for raw_text, sents in gold_tuples:
if gold_preproc:
raw_text = None
else:
sents = _merge_sents(sents)
for annot_tuples, brackets in sents:
if raw_text is None:
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
nlp.tagger(tokens)
nlp.entity(tokens)
nlp.parser(tokens)
else:
tokens = nlp(raw_text, merge_mwes=False)
gold = GoldParse(tokens, annot_tuples)
scorer.score(tokens, gold)
return scorer
def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None):
nlp = Language(data_dir=model_dir)
if beam_width is not None:
nlp.parser.cfg.beam_width = beam_width
gold_tuples = read_json_file(dev_loc)
scorer = Scorer()
out_file = codecs.open(out_loc, 'w', 'utf8')
for raw_text, sents in gold_tuples:
sents = _merge_sents(sents)
for annot_tuples, brackets in sents:
if raw_text is None:
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
nlp.tagger(tokens)
nlp.entity(tokens)
nlp.parser(tokens)
else:
tokens = nlp(raw_text, merge_mwes=False)
gold = GoldParse(tokens, annot_tuples)
scorer.score(tokens, gold, verbose=False)
for t in tokens:
out_file.write(
'%s\t%s\t%s\t%s\n' % (t.orth_, t.tag_, t.head.orth_, t.dep_)
)
return scorer
@plac.annotations(
train_loc=("Location of training file or directory"),
dev_loc=("Location of development file or directory"),
model_dir=("Location of output model directory",),
eval_only=("Skip training, and only evaluate", "flag", "e", bool),
corruption_level=("Amount of noise to add to training data", "option", "c", float),
gold_preproc=("Use gold-standard sentence boundaries in training?", "flag", "g", bool),
out_loc=("Out location", "option", "o", str),
n_sents=("Number of training sentences", "option", "n", int),
n_iter=("Number of training iterations", "option", "i", int),
verbose=("Verbose error reporting", "flag", "v", bool),
debug=("Debug mode", "flag", "d", bool),
)
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1,
eval_only=False):
gold_train = list(read_json_file(train_loc))
nlp = train(English, gold_train, model_dir,
feat_set='embed',
gold_preproc=gold_preproc, n_sents=n_sents,
corruption_level=corruption_level, n_iter=n_iter,
verbose=verbose)
#if out_loc:
# write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width)
scorer = evaluate(nlp, list(read_json_file(dev_loc)), gold_preproc=gold_preproc)
print 'TOK', 100-scorer.token_acc
print 'POS', scorer.tags_acc
print 'UAS', scorer.uas
print 'LAS', scorer.las
print 'NER P', scorer.ents_p
print 'NER R', scorer.ents_r
print 'NER F', scorer.ents_f
if __name__ == '__main__':
plac.call(main)

13
spacy/_theano.pxd Normal file
View File

@ -0,0 +1,13 @@
from ._ml cimport Model
from thinc.nn cimport InputLayer
cdef class TheanoModel(Model):
cdef InputLayer input_layer
cdef object train_func
cdef object predict_func
cdef object debug
cdef public float eta
cdef public float mu
cdef public float t

View File

@ -9,7 +9,8 @@ from os import path
cdef class TheanoModel(Model):
def __init__(self, n_classes, input_spec, train_func, predict_func, model_loc=None):
def __init__(self, n_classes, input_spec, train_func, predict_func, model_loc=None,
debug=None):
if model_loc is not None and path.isdir(model_loc):
model_loc = path.join(model_loc, 'model')
@ -20,6 +21,7 @@ cdef class TheanoModel(Model):
self.input_layer = InputLayer(input_spec, initializer)
self.train_func = train_func
self.predict_func = predict_func
self.debug = debug
self.n_classes = n_classes
self.n_feats = len(self.input_layer)
@ -27,7 +29,7 @@ cdef class TheanoModel(Model):
def predict(self, Example eg):
self.input_layer.fill(eg.embeddings, eg.atoms)
theano_scores = self.predict_func(eg.embeddings)
theano_scores = self.predict_func(eg.embeddings)[0]
cdef int i
for i in range(self.n_classes):
eg.scores[i] = theano_scores[i]
@ -35,10 +37,17 @@ cdef class TheanoModel(Model):
self.n_classes)
def train(self, Example eg):
self.predict(eg)
update, t, eta, mu = self.train_func(eg.embeddings, eg.scores, eg.costs)
self.input_layer.update(eg.atoms, update, self.t, self.eta, self.mu)
self.input_layer.fill(eg.embeddings, eg.atoms)
theano_scores, update, y = self.train_func(eg.embeddings, eg.costs, self.eta)
self.input_layer.update(update, eg.atoms, self.t, self.eta, self.mu)
for i in range(self.n_classes):
eg.scores[i] = theano_scores[i]
eg.guess = arg_max_if_true(<weight_t*>eg.scores.data, <int*>eg.is_valid.data,
self.n_classes)
eg.best = arg_max_if_zero(<weight_t*>eg.scores.data, <int*>eg.costs.data,
self.n_classes)
eg.cost = eg.costs[eg.guess]
self.t += 1
def end_training(self):
pass

View File

@ -355,3 +355,7 @@ trigrams = (
(N0W, N0p, N0lL, N0l2L),
(N0p, N0lL, N0l2L),
)
words = (S0w, N0w, S1w, N1w)
tags = (S0p, N0p, S1p, N1p)
labels = (S0L, N0L, S1L, S2L)

View File

@ -52,18 +52,21 @@ def get_templates(name):
return pf.ner
elif name == 'debug':
return pf.unigrams
elif name.startswith('embed'):
return ((10, pf.words), (10, pf.tags), (10, pf.labels))
else:
return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \
pf.tree_shape + pf.trigrams)
cdef class Parser:
def __init__(self, StringStore strings, model_dir, transition_system):
def __init__(self, StringStore strings, model_dir, transition_system,
get_model=Model):
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
self.cfg = Config.read(model_dir, 'config')
self.moves = transition_system(strings, self.cfg.labels)
templates = get_templates(self.cfg.features)
self.model = Model(self.moves.n_moves, templates, model_dir)
self.model = get_model(self.moves.n_moves, templates, model_dir)
def __call__(self, Tokens tokens):
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)