Add a Theano-backed parser training script and the hooks it needs.

  bin/parser/nn_train.py (new): trains and evaluates a pipeline whose
    parser model is a TheanoModel built from compile_theano_model.
  spacy/_theano.pxd (new), spacy/_theano.pyx: TheanoModel takes a debug
    argument, train() is reworked, and a no-op end_training() is added.
  spacy/syntax/_parse_features.pyx: adds the words/tags/labels tuples.
  spacy/syntax/parser.pyx: 'embed*' templates and a get_model argument.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch; hunk line counts were used to place
blank lines. Confirm context-line whitespace against the target tree before
applying. nn_train.py no longer matches blob 375996f4f: docstrings and
comments were added, add_noise uses isinstance(), and write_parses closes
its output file.

diff --git a/bin/parser/nn_train.py b/bin/parser/nn_train.py
new file mode 100755
index 000000000..375996f4f
--- /dev/null
+++ b/bin/parser/nn_train.py
@@ -0,0 +1,300 @@
+#!/usr/bin/env python
+from __future__ import division
+from __future__ import unicode_literals
+
+import os
+from os import path
+import shutil
+import codecs
+import random
+
+import plac
+import cProfile
+import pstats
+import re
+
+import spacy.util
+from spacy.en import English
+from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
+
+from spacy.syntax.util import Config
+from spacy.gold import read_json_file
+from spacy.gold import GoldParse
+
+from spacy.scorer import Scorer
+
+from thinc.theano_nn import compile_theano_model
+
+from spacy.syntax.parser import Parser
+from spacy._theano import TheanoModel
+
+
+def _corrupt(c, noise_level):
+    """Return c unchanged with probability 1 - noise_level, else corrupt it.
+
+    A space becomes a newline and vice versa; the characters . ' ! ? become
+    the empty string; any other value is lowercased.
+    """
+    if random.random() >= noise_level:
+        return c
+    elif c == ' ':
+        return '\n'
+    elif c == '\n':
+        return ' '
+    elif c in ['.', "'", "!", "?"]:
+        return ''
+    else:
+        return c.lower()
+
+
+def add_noise(orig, noise_level):
+    """With probability noise_level, pass each element of orig to _corrupt.
+
+    A list is processed item by item and items that come back empty are
+    dropped; any other value is iterated character by character and the
+    results are joined into one string. Otherwise orig is returned as is.
+    """
+    if random.random() >= noise_level:
+        return orig
+    elif isinstance(orig, list):
+        corrupted = [_corrupt(word, noise_level) for word in orig]
+        corrupted = [w for w in corrupted if w]
+        return corrupted
+    else:
+        return ''.join(_corrupt(c, noise_level) for c in orig)
+
+
+def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
+    """Run tokenizer, tagger, entity recognizer and parser; score the result.
+
+    If raw_text is None the tokens are built from the gold words in
+    annot_tuples[1] instead of tokenizing raw text.
+    """
+    if raw_text is None:
+        tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
+    else:
+        tokens = nlp.tokenizer(raw_text)
+    nlp.tagger(tokens)
+    nlp.entity(tokens)
+    nlp.parser(tokens)
+    gold = GoldParse(tokens, annot_tuples)
+    scorer.score(tokens, gold, verbose=verbose)
+
+
+def _merge_sents(sents):
+    """Merge per-sentence annotations into one document-level annotation.
+
+    Token ids, heads and bracket spans are shifted by the number of tokens
+    in the preceding sentences. Returns a one-item list [(deps, brackets)].
+    """
+    m_deps = [[], [], [], [], [], []]
+    m_brackets = []
+    i = 0
+    for (ids, words, tags, heads, labels, ner), brackets in sents:
+        m_deps[0].extend(id_ + i for id_ in ids)
+        m_deps[1].extend(words)
+        m_deps[2].extend(tags)
+        m_deps[3].extend(head + i for head in heads)
+        m_deps[4].extend(labels)
+        m_deps[5].extend(ner)
+        m_brackets.extend((b['first'] + i, b['last'] + i, b['label']) for b in brackets)
+        i += len(ids)
+    return [(m_deps, m_brackets)]
+
+
+def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
+          seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
+          verbose=False,
+          eta=0.01, mu=0.9, n_hidden=100, word_vec_len=10, pos_vec_len=10):
+    """Train the tagger, entity recognizer and a TheanoModel-based parser.
+
+    Recreates the 'deps', 'pos' and 'ner' directories under model_dir and
+    writes their configs, then makes n_iter passes over gold_tuples
+    (truncated to n_sents if positive). Each example is scored before it
+    is trained on, and gold_tuples is shuffled in place after every pass.
+    Returns the trained Language instance.
+
+    NOTE(review): eta, mu, word_vec_len and pos_vec_len are accepted but
+    never read in this function.
+    """
+    dep_model_dir = path.join(model_dir, 'deps')
+    pos_model_dir = path.join(model_dir, 'pos')
+    ner_model_dir = path.join(model_dir, 'ner')
+    if path.exists(dep_model_dir):
+        shutil.rmtree(dep_model_dir)
+    if path.exists(pos_model_dir):
+        shutil.rmtree(pos_model_dir)
+    if path.exists(ner_model_dir):
+        shutil.rmtree(ner_model_dir)
+    os.mkdir(dep_model_dir)
+    os.mkdir(pos_model_dir)
+    os.mkdir(ner_model_dir)
+    setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
+
+    Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
+                 labels=Language.ParserTransitionSystem.get_labels(gold_tuples))
+    Config.write(ner_model_dir, 'config', features='ner', seed=seed,
+                 labels=Language.EntityTransitionSystem.get_labels(gold_tuples),
+                 beam_width=0)
+
+    if n_sents > 0:
+        gold_tuples = gold_tuples[:n_sents]
+
+    nlp = Language(data_dir=model_dir)
+
+    def make_model(n_classes, input_spec, model_dir):
+        print input_spec
+        n_in = sum(n_cols * len(fields) for (n_cols, fields) in input_spec)
+        print 'Compiling'
+        debug, train_func, predict_func = compile_theano_model(n_classes, n_hidden,
+                                                               n_in, 0.0, 0.0)
+        print 'Done'
+        return TheanoModel(
+            n_classes,
+            input_spec,
+            train_func,
+            predict_func,
+            model_loc=model_dir,
+            debug=debug)
+
+    nlp._parser = Parser(nlp.vocab.strings, dep_model_dir, nlp.ParserTransitionSystem,
+                         make_model)
+
+    print "Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %"
+    for itn in range(n_iter):
+        scorer = Scorer()
+        loss = 0
+        for raw_text, sents in gold_tuples:
+            if gold_preproc:
+                raw_text = None
+            else:
+                sents = _merge_sents(sents)
+            for annot_tuples, ctnt in sents:
+                if len(annot_tuples[1]) == 1:
+                    continue
+                score_model(scorer, nlp, raw_text, annot_tuples,
+                            verbose=verbose if itn >= 2 else False)
+                if raw_text is None:
+                    words = add_noise(annot_tuples[1], corruption_level)
+                    tokens = nlp.tokenizer.tokens_from_list(words)
+                else:
+                    raw_text = add_noise(raw_text, corruption_level)
+                    tokens = nlp.tokenizer(raw_text)
+                nlp.tagger(tokens)
+                gold = GoldParse(tokens, annot_tuples, make_projective=True)
+                if not gold.is_projective:
+                    raise Exception(
+                        "Non-projective sentence in training, after we should "
+                        "have enforced projectivity: %s" % annot_tuples
+                    )
+                loss += nlp.parser.train(tokens, gold)
+                nlp.entity.train(tokens, gold)
+                nlp.tagger.train(tokens, gold.tags)
+        random.shuffle(gold_tuples)
+        print '%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
+                                                   scorer.tags_acc,
+                                                   scorer.token_acc)
+    nlp.parser.model.end_training()
+    nlp.entity.model.end_training()
+    nlp.tagger.model.end_training()
+    nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt'))
+    return nlp
+
+
+def evaluate(nlp, gold_tuples, gold_preproc=True):
+    """Score nlp on gold_tuples and return the Scorer.
+
+    With gold_preproc the gold sentences and gold words are used as input;
+    otherwise sentences are merged and raw_text (if not None) is run through nlp.
+    """
+    scorer = Scorer()
+    for raw_text, sents in gold_tuples:
+        if gold_preproc:
+            raw_text = None
+        else:
+            sents = _merge_sents(sents)
+        for annot_tuples, brackets in sents:
+            if raw_text is None:
+                tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
+                nlp.tagger(tokens)
+                nlp.entity(tokens)
+                nlp.parser(tokens)
+            else:
+                tokens = nlp(raw_text, merge_mwes=False)
+            gold = GoldParse(tokens, annot_tuples)
+            scorer.score(tokens, gold)
+    return scorer
+
+
+def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None):
+    """Parse the data in dev_loc and write one token per line to out_loc.
+
+    Each line holds the token's orth, tag, head orth and dependency label,
+    tab-separated. Returns the Scorer accumulated over all documents.
+    """
+    nlp = Language(data_dir=model_dir)
+    if beam_width is not None:
+        nlp.parser.cfg.beam_width = beam_width
+    gold_tuples = read_json_file(dev_loc)
+    scorer = Scorer()
+    with codecs.open(out_loc, 'w', 'utf8') as out_file:
+        for raw_text, sents in gold_tuples:
+            sents = _merge_sents(sents)
+            for annot_tuples, brackets in sents:
+                if raw_text is None:
+                    tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
+                    nlp.tagger(tokens)
+                    nlp.entity(tokens)
+                    nlp.parser(tokens)
+                else:
+                    tokens = nlp(raw_text, merge_mwes=False)
+                gold = GoldParse(tokens, annot_tuples)
+                scorer.score(tokens, gold, verbose=False)
+                for t in tokens:
+                    out_file.write(
+                        '%s\t%s\t%s\t%s\n' % (t.orth_, t.tag_, t.head.orth_, t.dep_)
+                    )
+    return scorer
+
+
+@plac.annotations(
+    train_loc=("Location of training file or directory"),
+    dev_loc=("Location of development file or directory"),
+    model_dir=("Location of output model directory",),
+    eval_only=("Skip training, and only evaluate", "flag", "e", bool),
+    corruption_level=("Amount of noise to add to training data", "option", "c", float),
+    gold_preproc=("Use gold-standard sentence boundaries in training?", "flag", "g", bool),
+    out_loc=("Out location", "option", "o", str),
+    n_sents=("Number of training sentences", "option", "n", int),
+    n_iter=("Number of training iterations", "option", "i", int),
+    verbose=("Verbose error reporting", "flag", "v", bool),
+    debug=("Debug mode", "flag", "d", bool),
+)
+def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
+         debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1,
+         eval_only=False):
+    # Train on train_loc, then evaluate on dev_loc and print the scores.
+    # NOTE(review): out_loc, debug, beam_width and eval_only are not used
+    # below, so training always runs before evaluation.
+    gold_train = list(read_json_file(train_loc))
+    nlp = train(English, gold_train, model_dir,
+                feat_set='embed',
+                gold_preproc=gold_preproc, n_sents=n_sents,
+                corruption_level=corruption_level, n_iter=n_iter,
+                verbose=verbose)
+    #if out_loc:
+    #    write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width)
+    scorer = evaluate(nlp, list(read_json_file(dev_loc)), gold_preproc=gold_preproc)
+
+    print 'TOK', 100-scorer.token_acc
+    print 'POS', scorer.tags_acc
+    print 'UAS', scorer.uas
+    print 'LAS', scorer.las
+
+    print 'NER P', scorer.ents_p
+    print 'NER R', scorer.ents_r
+    print 'NER F', scorer.ents_f
+
+
+if __name__ == '__main__':
+    plac.call(main)
diff --git a/spacy/_theano.pxd b/spacy/_theano.pxd
new file mode 100644
index 000000000..cad0736c2
--- /dev/null
+++ b/spacy/_theano.pxd
@@ -0,0 +1,13 @@
+from ._ml cimport Model
+from thinc.nn cimport InputLayer
+
+
+cdef class TheanoModel(Model):
+    cdef InputLayer input_layer
+    cdef object train_func
+    cdef object predict_func
+    cdef object debug
+
+    cdef public float eta
+    cdef public float mu
+    cdef public float t
diff --git a/spacy/_theano.pyx b/spacy/_theano.pyx
index 702208d18..b791c4f42 100644
--- a/spacy/_theano.pyx
+++ b/spacy/_theano.pyx
@@ -9,7 +9,8 @@ from os import path
 
 
 cdef class TheanoModel(Model):
-    def __init__(self, n_classes, input_spec, train_func, predict_func, model_loc=None):
+    def __init__(self, n_classes, input_spec, train_func, predict_func, model_loc=None,
+                 debug=None):
         if model_loc is not None and path.isdir(model_loc):
             model_loc = path.join(model_loc, 'model')
 
@@ -20,6 +21,7 @@ cdef class TheanoModel(Model):
         self.input_layer = InputLayer(input_spec, initializer)
         self.train_func = train_func
         self.predict_func = predict_func
+        self.debug = debug
         self.n_classes = n_classes
         self.n_feats = len(self.input_layer)
 
@@ -27,7 +29,7 @@ cdef class TheanoModel(Model):
 
     def predict(self, Example eg):
         self.input_layer.fill(eg.embeddings, eg.atoms)
-        theano_scores = self.predict_func(eg.embeddings)
+        theano_scores = self.predict_func(eg.embeddings)[0]
         cdef int i
         for i in range(self.n_classes):
             eg.scores[i] = theano_scores[i]
@@ -35,10 +37,17 @@ cdef class TheanoModel(Model):
                                    self.n_classes)
 
     def train(self, Example eg):
-        self.predict(eg)
-        update, t, eta, mu = self.train_func(eg.embeddings, eg.scores, eg.costs)
-        self.input_layer.update(eg.atoms, update, self.t, self.eta, self.mu)
+        self.input_layer.fill(eg.embeddings, eg.atoms)
+        theano_scores, update, y = self.train_func(eg.embeddings, eg.costs, self.eta)
+        self.input_layer.update(update, eg.atoms, self.t, self.eta, self.mu)
+        for i in range(self.n_classes):
+            eg.scores[i] = theano_scores[i]
+        eg.guess = arg_max_if_true(eg.scores.data, eg.is_valid.data,
+                                   self.n_classes)
         eg.best = arg_max_if_zero(eg.scores.data, eg.costs.data,
                                   self.n_classes)
         eg.cost = eg.costs[eg.guess]
         self.t += 1
+
+    def end_training(self):
+        pass
diff --git a/spacy/syntax/_parse_features.pyx b/spacy/syntax/_parse_features.pyx
index efefc7273..1adeaef83 100644
--- a/spacy/syntax/_parse_features.pyx
+++ b/spacy/syntax/_parse_features.pyx
@@ -355,3 +355,7 @@ trigrams = (
     (N0W, N0p, N0lL, N0l2L),
     (N0p, N0lL, N0l2L),
 )
+
+words = (S0w, N0w, S1w, N1w)
+tags = (S0p, N0p, S1p, N1p)
+labels = (S0L, N0L, S1L, S2L)
diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx
index 33ae5b497..66d598b88 100644
--- a/spacy/syntax/parser.pyx
+++ b/spacy/syntax/parser.pyx
@@ -52,18 +52,21 @@ def get_templates(name):
         return pf.ner
     elif name == 'debug':
         return pf.unigrams
+    elif name.startswith('embed'):
+        return ((10, pf.words), (10, pf.tags), (10, pf.labels))
     else:
         return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \
                 pf.tree_shape + pf.trigrams)
 
 
 cdef class Parser:
-    def __init__(self, StringStore strings, model_dir, transition_system):
+    def __init__(self, StringStore strings, model_dir, transition_system,
+                 get_model=Model):
         assert os.path.exists(model_dir) and os.path.isdir(model_dir)
         self.cfg = Config.read(model_dir, 'config')
         self.moves = transition_system(strings, self.cfg.labels)
         templates = get_templates(self.cfg.features)
-        self.model = Model(self.moves.n_moves, templates, model_dir)
+        self.model = get_model(self.moves.n_moves, templates, model_dir)
 
     def __call__(self, Tokens tokens):
         cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)