diff --git a/bin/parser/train.py b/bin/parser/train.py index 998c74819..bb66a2969 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -21,7 +21,8 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir from spacy.syntax.parser import GreedyParser from spacy.syntax.parser import OracleError from spacy.syntax.util import Config -from spacy.syntax.conll import GoldParse, is_punct_label +from spacy.syntax.conll import read_docparse_file +from spacy.syntax.conll import GoldParse def is_punct_label(label): @@ -183,47 +184,56 @@ def get_labels(sents): return list(sorted(left_labels)), list(sorted(right_labels)) -def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0, - gold_preproc=False, force_gold=False): +def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0, + gold_preproc=False, force_gold=False, n_sents=0): print "Setup model dir" dep_model_dir = path.join(model_dir, 'deps') pos_model_dir = path.join(model_dir, 'pos') + ner_model_dir = path.join(model_dir, 'ner') if path.exists(dep_model_dir): shutil.rmtree(dep_model_dir) if path.exists(pos_model_dir): shutil.rmtree(pos_model_dir) + if path.exists(ner_model_dir): + shutil.rmtree(ner_model_dir) os.mkdir(dep_model_dir) os.mkdir(pos_model_dir) - setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, - pos_model_dir) + os.mkdir(ner_model_dir) + + setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir) + + gold_tuples = read_docparse_file(train_loc) - labels = Language.ParserTransitionSystem.get_labels(gold_sents) Config.write(dep_model_dir, 'config', features=feat_set, seed=seed, - labels=labels) + labels=Language.ParserTransitionSystem.get_labels(gold_tuples)) + Config.write(ner_model_dir, 'config', features=feat_set, seed=seed, + labels=Language.EntityTransitionSystem.get_labels(gold_tuples)) + nlp = Language() for itn in range(n_iter): - heads_corr = 0 + dep_corr = 0 pos_corr = 0 n_tokens = 0 - n_all_tokens = 0 - for gold_sent in gold_sents: + for raw_text, segmented_text, annot_tuples in gold_tuples: if gold_preproc: - #print ' '.join(gold_sent.words) - tokens = nlp.tokenizer.tokens_from_list(gold_sent.words) - gold_sent.map_heads(nlp.parser.moves.label_ids) + sents = [nlp.tokenizer.tokens_from_list(s) for s in segmented_text] else: - tokens = nlp.tokenizer(gold_sent.raw_text) - gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids) - nlp.tagger(tokens) - heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold) - pos_corr += nlp.tagger.train(tokens, gold_sent.tags) - n_tokens += gold_sent.n_non_punct - n_all_tokens += len(tokens) - acc = float(heads_corr) / n_tokens - pos_acc = float(pos_corr) / n_all_tokens + sents = [nlp.tokenizer(raw_text)] + for tokens in sents: + + gold = GoldParse(tokens, annot_tuples, nlp.tags, + nlp.parser.moves.label_ids, + nlp.entity.moves.label_ids) + + nlp.tagger(tokens) + dep_corr += nlp.parser.train(tokens, gold, force_gold=force_gold) + pos_corr += nlp.tagger.train(tokens, gold.tags_) + n_tokens += len(tokens) + acc = float(dep_corr) / n_tokens + pos_acc = float(pos_corr) / n_tokens print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc - random.shuffle(gold_sents) + random.shuffle(gold_tuples) nlp.parser.model.end_training() nlp.tagger.model.end_training() return acc @@ -239,22 +249,22 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False): total = 0 skipped = 0 loss = 0 - with codecs.open(dev_loc, 'r', 'utf8') as file_: - #paragraphs = read_tokenized_gold(file_) - paragraphs = read_docparse_gold(file_) - for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer, - gold_preproc=gold_preproc): - assert len(tokens) == len(labels) - nlp.tagger(tokens) - nlp.parser(tokens) + gold_tuples = read_docparse_file(train_loc) + for raw_text, segmented_text, annot_tuples in gold_tuples: + if gold_preproc: + tokens = nlp.tokenizer.tokens_from_list(gold_sent.words) + nlp.tagger(tokens) + nlp.parser(tokens) + gold_sent.map_heads(nlp.parser.moves.label_ids) + else: + tokens = nlp(gold_sent.raw_text) + loss += gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids) for i, token in enumerate(tokens): pos_corr += token.tag_ == gold_sent.tags[i] n_tokens += 1 if gold_sent.heads[i] is None: skipped += 1 continue - #print i, token.orth_, token.head.i, gold_sent.py_heads[i], gold_sent.labels[i], - #print gold_sent.is_correct(i, token.head.i) if gold_sent.labels[i] != 'P': n_corr += gold_sent.is_correct(i, token.head.i) total += 1 @@ -263,12 +273,6 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False): return float(n_corr) / (total + loss) -def read_gold(loc, n=0): - sent_strs = open(loc).read().strip().split('\n\n') - if n == 0: - n = len(sent_strs) - return [GoldParse.from_docparse(sent) for sent in sent_strs[:n]] - @plac.annotations( train_loc=("Training file location",), @@ -277,9 +281,9 @@ def read_gold(loc, n=0): n_sents=("Number of training sentences", "option", "n", int) ) def main(train_loc, dev_loc, model_dir, n_sents=0): - #train(English, read_gold(train_loc, n=n_sents), model_dir, - # gold_preproc=False, force_gold=False) - print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False) + train(English, train_loc, model_dir, + gold_preproc=False, force_gold=False, n_sents=n_sents) + print evaluate(English, dev_loc, model_dir, gold_preproc=False) if __name__ == '__main__':